Uprev new code and remove C++17 workarounds. These changes have undergone internal review prior to open-sourcing. This CL was created by running the export_to_chromeos.sh export script, with manual updates to BUILD.gn. The upstream libtextclassifier code in google3 is current as of cl/351124900. Changes to the export script since the previous uprev can be seen in cl/348401862. This CL contains only refactoring changes. It is not expected to introduce any functional/feature changes. BUG=b:174953443 TEST=chromeos: (in conjunction with minor libtextclassifier ebuild TEST=changes) ML Service unit tests for tclib pass. TEST=Manual testing of Quick Answers on DUT (e.g. QA popup visible TEST=on right-clicking non-English words). Cq-Depend: 2627533 Change-Id: Ie560e5c10c0169fa55792572371adb0c49478a12 Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/third_party/libtextclassifier/+/2628507 Tested-by: Amanda Deacon <amandadeacon@chromium.org> Commit-Queue: Amanda Deacon <amandadeacon@chromium.org> Reviewed-by: Honglin Yu <honglinyu@chromium.org> Reviewed-by: Andrew Moylan <amoylan@chromium.org>
diff --git a/BUILD.gn b/BUILD.gn index 05dd10a..179d8f1 100644 --- a/BUILD.gn +++ b/BUILD.gn
@@ -32,8 +32,6 @@ sources = [ "annotator/entity-data.fbs", "annotator/experimental/experimental.fbs", - "annotator/grammar/dates/dates.fbs", - "annotator/grammar/dates/timezone-code.fbs", "annotator/model.fbs", "annotator/person_name/person_name_model.fbs", "lang_id/common/flatbuffers/embedding-network.fbs", @@ -41,7 +39,7 @@ "utils/container/bit-vector.fbs", "utils/flatbuffers/flatbuffers.fbs", "utils/codepoint-range.fbs", - "utils/grammar/next/semantics/expression.fbs", + "utils/grammar/semantics/expression.fbs", "utils/grammar/rules.fbs", "utils/i18n/language-tag.fbs", "utils/intents/intent-config.fbs", @@ -67,18 +65,11 @@ "annotator/annotator.cc", "annotator/cached-features.cc", "annotator/datetime/extractor.cc", - "annotator/datetime/parser.cc", + "annotator/datetime/regex-parser.cc", "annotator/datetime/utils.cc", "annotator/duration/duration.cc", "annotator/feature-processor.cc", "annotator/flatbuffer-utils.cc", - "annotator/grammar/dates/annotations/annotation-util.cc", - "annotator/grammar/dates/cfg-datetime-annotator.cc", - "annotator/grammar/dates/extractor.cc", - "annotator/grammar/dates/parser.cc", - "annotator/grammar/dates/utils/annotation-keys.cc", - "annotator/grammar/dates/utils/date-match.cc", - "annotator/grammar/dates/utils/date-utils.cc", "annotator/grammar/grammar-annotator.cc", "annotator/grammar/utils.cc", "annotator/model-executor.cc", @@ -135,14 +126,22 @@ "utils/container/sorted-strings-table.cc", "utils/flatbuffers/mutable.cc", "utils/flatbuffers/reflection.cc", - "utils/grammar/lexer.cc", - "utils/grammar/match.cc", - "utils/grammar/matcher.cc", + "utils/grammar/analyzer.cc", + "utils/grammar/parsing/derivation.cc", + "utils/grammar/parsing/lexer.cc", + "utils/grammar/parsing/matcher.cc", + "utils/grammar/parsing/parser.cc", + "utils/grammar/parsing/parse-tree.cc", "utils/grammar/rules-utils.cc", + "utils/grammar/semantics/composer.cc", + "utils/grammar/semantics/evaluators/arithmetic-eval.cc", + 
"utils/grammar/semantics/evaluators/compose-eval.cc", + "utils/grammar/semantics/evaluators/merge-values-eval.cc", "utils/grammar/utils/ir.cc", "utils/grammar/utils/rules.cc", "utils/hash/farmhash.cc", "utils/i18n/locale.cc", + "utils/i18n/locale-list.cc", "utils/math/fastexp.cc", "utils/math/softmax.cc", "utils/memory/mmap.cc",
diff --git a/annotator/annotator.cc b/annotator/annotator.cc index eb3c34b..4af4a93 100644 --- a/annotator/annotator.cc +++ b/annotator/annotator.cc
@@ -19,12 +19,14 @@ #include <cmath> #include <cstddef> #include <iterator> +#include <limits> #include <numeric> #include <string> #include <unordered_map> #include <vector> #include "annotator/collections.h" +#include "annotator/datetime/regex-parser.h" #include "annotator/flatbuffer-utils.h" #include "annotator/knowledge/knowledge-engine-types.h" #include "annotator/model_generated.h" @@ -32,7 +34,9 @@ #include "utils/base/logging.h" #include "utils/base/status.h" #include "utils/base/statusor.h" +#include "utils/calendar/calendar.h" #include "utils/checksum.h" +#include "utils/i18n/locale-list.h" #include "utils/i18n/locale.h" #include "utils/math/softmax.h" #include "utils/normalization.h" @@ -104,12 +108,8 @@ } // Returns whether the provided input is valid: -// * Valid utf8 text. // * Sane span indices. bool IsValidSpanInput(const UnicodeText& context, const CodepointSpan& span) { - if (!context.is_valid()) { - return false; - } return (span.first >= 0 && span.first < span.second && span.second <= context.size_codepoints()); } @@ -126,37 +126,6 @@ return ints_set; } -DateAnnotationOptions ToDateAnnotationOptions( - const GrammarDatetimeModel_::AnnotationOptions* fb_annotation_options, - const std::string& reference_timezone, const int64 reference_time_ms_utc) { - DateAnnotationOptions result_annotation_options; - result_annotation_options.base_timestamp_millis = reference_time_ms_utc; - result_annotation_options.reference_timezone = reference_timezone; - if (fb_annotation_options != nullptr) { - result_annotation_options.enable_special_day_offset = - fb_annotation_options->enable_special_day_offset(); - result_annotation_options.merge_adjacent_components = - fb_annotation_options->merge_adjacent_components(); - result_annotation_options.enable_date_range = - fb_annotation_options->enable_date_range(); - result_annotation_options.include_preposition = - fb_annotation_options->include_preposition(); - if (fb_annotation_options->extra_requested_dates() != 
nullptr) { - for (const auto& extra_requested_date : - *fb_annotation_options->extra_requested_dates()) { - result_annotation_options.extra_requested_dates.push_back( - extra_requested_date->str()); - } - } - if (fb_annotation_options->ignored_spans() != nullptr) { - for (const auto& ignored_span : *fb_annotation_options->ignored_spans()) { - result_annotation_options.ignored_spans.push_back(ignored_span->str()); - } - } - } - return result_annotation_options; -} - } // namespace tflite::Interpreter* InterpreterManager::SelectionInterpreter() { @@ -445,25 +414,9 @@ return; } } - if (model_->grammar_datetime_model() && - model_->grammar_datetime_model()->datetime_rules()) { - cfg_datetime_parser_.reset(new dates::CfgDatetimeAnnotator( - unilib_, - /*tokenizer_options=*/ - model_->grammar_datetime_model()->grammar_tokenizer_options(), - calendarlib_, - /*datetime_rules=*/model_->grammar_datetime_model()->datetime_rules(), - model_->grammar_datetime_model()->target_classification_score(), - model_->grammar_datetime_model()->priority_score())); - if (!cfg_datetime_parser_) { - TC3_LOG(ERROR) << "Could not initialize context free grammar based " - "datetime parser."; - return; - } - } if (model_->datetime_model()) { - datetime_parser_ = DatetimeParser::Instance( + datetime_parser_ = RegexDatetimeParser::Instance( model_->datetime_model(), unilib_, calendarlib_, decompressor.get()); if (!datetime_parser_) { TC3_LOG(ERROR) << "Could not initialize datetime parser."; @@ -661,7 +614,11 @@ return true; } -void Annotator::SetLangId(const libtextclassifier3::mobile::lang_id::LangId* lang_id) { +bool Annotator::SetLangId(const libtextclassifier3::mobile::lang_id::LangId* lang_id) { + if (lang_id == nullptr) { + return false; + } + lang_id_ = lang_id; if (lang_id_ != nullptr && model_->translate_annotator_options() && model_->translate_annotator_options()->enabled()) { @@ -670,6 +627,7 @@ } else { translate_annotator_.reset(nullptr); } + return true; } bool 
Annotator::InitializePersonNameEngineFromUnownedBuffer(const void* buffer, @@ -853,6 +811,11 @@ CodepointSpan Annotator::SuggestSelection( const std::string& context, CodepointSpan click_indices, const SelectionOptions& options) const { + if (context.size() > std::numeric_limits<int>::max()) { + TC3_LOG(ERROR) << "Rejecting too long input: " << context.size(); + return {}; + } + CodepointSpan original_click_indices = click_indices; if (!initialized_) { TC3_LOG(ERROR) << "Not initialized"; @@ -884,6 +847,11 @@ const UnicodeText context_unicode = UTF8ToUnicodeText(context, /*do_copy=*/false); + if (!unilib_->IsValidUtf8(context_unicode)) { + TC3_LOG(ERROR) << "Rejecting input, invalid UTF8."; + return original_click_indices; + } + if (!IsValidSpanInput(context_unicode, click_indices)) { TC3_VLOG(1) << "Trying to run SuggestSelection with invalid input, indices: " @@ -986,9 +954,11 @@ candidates.annotated_spans[0].push_back(grammar_suggested_span); } - if (pod_ner_annotator_ != nullptr && options.use_pod_ner) { - candidates.annotated_spans[0].push_back( - pod_ner_annotator_->SuggestSelection(context_unicode, click_indices)); + AnnotatedSpan pod_ner_suggested_span; + if (pod_ner_annotator_ != nullptr && options.use_pod_ner && + pod_ner_annotator_->SuggestSelection(context_unicode, click_indices, + &pod_ner_suggested_span)) { + candidates.annotated_spans[0].push_back(pod_ner_suggested_span); } if (experimental_annotator_ != nullptr) { @@ -1696,7 +1666,7 @@ const std::string& context, const CodepointSpan& selection_indices, const ClassificationOptions& options, std::vector<ClassificationResult>* classification_results) const { - if (!datetime_parser_ && !cfg_datetime_parser_) { + if (!datetime_parser_) { return true; } @@ -1704,35 +1674,20 @@ UTF8ToUnicodeText(context, /*do_copy=*/false) .UTF8Substring(selection_indices.first, selection_indices.second); - std::vector<DatetimeParseResultSpan> datetime_spans; - - if (cfg_datetime_parser_) { - if 
(!(model_->grammar_datetime_model()->enabled_modes() & - ModeFlag_CLASSIFICATION)) { - return true; - } - std::vector<Locale> parsed_locales; - ParseLocales(options.locales, &parsed_locales); - cfg_datetime_parser_->Parse( - selection_text, - ToDateAnnotationOptions( - model_->grammar_datetime_model()->annotation_options(), - options.reference_timezone, options.reference_time_ms_utc), - parsed_locales, &datetime_spans); + LocaleList locale_list = LocaleList::ParseFrom(options.locales); + StatusOr<std::vector<DatetimeParseResultSpan>> result_status = + datetime_parser_->Parse(selection_text, options.reference_time_ms_utc, + options.reference_timezone, locale_list, + ModeFlag_CLASSIFICATION, + options.annotation_usecase, + /*anchor_start_end=*/true); + if (!result_status.ok()) { + TC3_LOG(ERROR) << "Error during parsing datetime."; + return false; } - if (datetime_parser_) { - if (!datetime_parser_->Parse(selection_text, options.reference_time_ms_utc, - options.reference_timezone, options.locales, - ModeFlag_CLASSIFICATION, - options.annotation_usecase, - /*anchor_start_end=*/true, &datetime_spans)) { - TC3_LOG(ERROR) << "Error during parsing datetime."; - return false; - } - } - - for (const DatetimeParseResultSpan& datetime_span : datetime_spans) { + for (const DatetimeParseResultSpan& datetime_span : + result_status.ValueOrDie()) { // Only consider the result valid if the selection and extracted datetime // spans exactly match. 
if (CodepointSpan(datetime_span.span.first + selection_indices.first, @@ -1757,6 +1712,10 @@ std::vector<ClassificationResult> Annotator::ClassifyText( const std::string& context, const CodepointSpan& selection_indices, const ClassificationOptions& options) const { + if (context.size() > std::numeric_limits<int>::max()) { + TC3_LOG(ERROR) << "Rejecting too long input: " << context.size(); + return {}; + } if (!initialized_) { TC3_LOG(ERROR) << "Not initialized"; return {}; @@ -1784,8 +1743,15 @@ return {}; } - if (!IsValidSpanInput(UTF8ToUnicodeText(context, /*do_copy=*/false), - selection_indices)) { + const UnicodeText context_unicode = + UTF8ToUnicodeText(context, /*do_copy=*/false); + + if (!unilib_->IsValidUtf8(context_unicode)) { + TC3_LOG(ERROR) << "Rejecting input, invalid UTF8."; + return {}; + } + + if (!IsValidSpanInput(context_unicode, selection_indices)) { TC3_VLOG(1) << "Trying to run ClassifyText with invalid input: " << selection_indices.first << " " << selection_indices.second; return {}; @@ -1859,9 +1825,6 @@ candidates.back().source = AnnotatedSpan::Source::DATETIME; } - const UnicodeText context_unicode = - UTF8ToUnicodeText(context, /*do_copy=*/false); - // Try the number annotator. // TODO(b/126579108): Propagate error status. 
ClassificationResult number_annotator_result; @@ -2044,25 +2007,13 @@ } const int offset = std::distance(context_unicode.begin(), line.first); - if (local_chunks.empty()) { - continue; - } - const UnicodeText line_unicode = - UTF8ToUnicodeText(line_str, /*do_copy=*/false); - std::vector<UnicodeText::const_iterator> line_codepoints = - line_unicode.Codepoints(); - line_codepoints.push_back(line_unicode.end()); for (const TokenSpan& chunk : local_chunks) { CodepointSpan codepoint_span = - TokenSpanToCodepointSpan(line_tokens, chunk); - codepoint_span = selection_feature_processor_->StripBoundaryCodepoints( - /*span_begin=*/line_codepoints[codepoint_span.first], - /*span_end=*/line_codepoints[codepoint_span.second], codepoint_span); + selection_feature_processor_->StripBoundaryCodepoints( + line_str, TokenSpanToCodepointSpan(line_tokens, chunk)); if (model_->selection_options()->strip_unpaired_brackets()) { - codepoint_span = StripUnpairedBrackets( - /*span_begin=*/line_codepoints[codepoint_span.first], - /*span_end=*/line_codepoints[codepoint_span.second], codepoint_span, - *unilib_); + codepoint_span = + StripUnpairedBrackets(context_unicode, codepoint_span, *unilib_); } // Skip empty spans. 
@@ -2161,10 +2112,6 @@ const UnicodeText context_unicode = UTF8ToUnicodeText(context, /*do_copy=*/false); - if (!context_unicode.is_valid()) { - return Status(StatusCode::INVALID_ARGUMENT, - "Context string isn't valid UTF8."); - } std::vector<Locale> detected_text_language_tags; if (!ParseLocales(options.detected_text_language_tags, @@ -2384,15 +2331,21 @@ std::vector<std::string> text_to_annotate; text_to_annotate.reserve(string_fragments.size()); + std::vector<FragmentMetadata> fragment_metadata; + fragment_metadata.reserve(string_fragments.size()); for (const auto& string_fragment : string_fragments) { text_to_annotate.push_back(string_fragment.text); + fragment_metadata.push_back( + {.relative_bounding_box_top = string_fragment.bounding_box_top, + .relative_bounding_box_height = string_fragment.bounding_box_height}); } // KnowledgeEngine is special, because it supports annotation of multiple // fragments at once. if (knowledge_engine_ && !knowledge_engine_ - ->ChunkMultipleSpans(text_to_annotate, options.annotation_usecase, + ->ChunkMultipleSpans(text_to_annotate, fragment_metadata, + options.annotation_usecase, options.location_context, options.permissions, options.annotate_mode, &annotation_candidates) .ok()) { @@ -2445,6 +2398,18 @@ std::vector<AnnotatedSpan> Annotator::Annotate( const std::string& context, const AnnotationOptions& options) const { + if (context.size() > std::numeric_limits<int>::max()) { + TC3_LOG(ERROR) << "Rejecting too long input."; + return {}; + } + + const UnicodeText context_unicode = + UTF8ToUnicodeText(context, /*do_copy=*/false); + if (!unilib_->IsValidUtf8(context_unicode)) { + TC3_LOG(ERROR) << "Rejecting input, invalid UTF8."; + return {}; + } + std::vector<InputFragment> string_fragments; string_fragments.push_back({.text = context}); StatusOr<Annotations> annotations = @@ -3117,31 +3082,21 @@ AnnotationUsecase annotation_usecase, bool is_serialized_entity_data_enabled, std::vector<AnnotatedSpan>* result) const { - 
std::vector<DatetimeParseResultSpan> datetime_spans; - if (cfg_datetime_parser_) { - if (!(model_->grammar_datetime_model()->enabled_modes() & mode)) { - return true; - } - std::vector<Locale> parsed_locales; - ParseLocales(locales, &parsed_locales); - cfg_datetime_parser_->Parse( - context_unicode.ToUTF8String(), - ToDateAnnotationOptions( - model_->grammar_datetime_model()->annotation_options(), - reference_timezone, reference_time_ms_utc), - parsed_locales, &datetime_spans); + if (!datetime_parser_) { + return true; + } + LocaleList locale_list = LocaleList::ParseFrom(locales); + StatusOr<std::vector<DatetimeParseResultSpan>> result_status = + datetime_parser_->Parse(context_unicode, reference_time_ms_utc, + reference_timezone, locale_list, mode, + annotation_usecase, + /*anchor_start_end=*/false); + if (!result_status.ok()) { + return false; } - if (datetime_parser_) { - if (!datetime_parser_->Parse(context_unicode, reference_time_ms_utc, - reference_timezone, locales, mode, - annotation_usecase, - /*anchor_start_end=*/false, &datetime_spans)) { - return false; - } - } - - for (const DatetimeParseResultSpan& datetime_span : datetime_spans) { + for (const DatetimeParseResultSpan& datetime_span : + result_status.ValueOrDie()) { AnnotatedSpan annotated_span; annotated_span.span = datetime_span.span; for (const DatetimeParseResult& parse_result : datetime_span.data) { diff --git a/annotator/annotator.h b/annotator/annotator.h index a921591..a334b49 100644 --- a/annotator/annotator.h +++ b/annotator/annotator.h
@@ -29,7 +29,6 @@ #include "annotator/duration/duration.h" #include "annotator/experimental/experimental.h" #include "annotator/feature-processor.h" -#include "annotator/grammar/dates/cfg-datetime-annotator.h" #include "annotator/grammar/grammar-annotator.h" #include "annotator/installed_app/installed-app-engine.h" #include "annotator/knowledge/knowledge-engine.h" @@ -45,6 +44,7 @@ #include "annotator/zlib-utils.h" #include "utils/base/status.h" #include "utils/base/statusor.h" +#include "utils/calendar/calendar.h" #include "utils/flatbuffers/flatbuffers.h" #include "utils/flatbuffers/mutable.h" #include "utils/i18n/locale.h" @@ -173,7 +173,7 @@ bool InitializeExperimentalAnnotators(); // Sets up the lang-id instance that should be used. - void SetLangId(const libtextclassifier3::mobile::lang_id::LangId* lang_id); + bool SetLangId(const libtextclassifier3::mobile::lang_id::LangId* lang_id); // Runs inference for given a context and current selection (i.e. index // of the first and one past last selected characters (utf8 codepoint @@ -440,8 +440,6 @@ std::unique_ptr<const FeatureProcessor> classification_feature_processor_; std::unique_ptr<const DatetimeParser> datetime_parser_; - std::unique_ptr<const dates::CfgDatetimeAnnotator> cfg_datetime_parser_; - std::unique_ptr<const GrammarAnnotator> grammar_annotator_; std::string owned_buffer_;
diff --git a/annotator/datetime/datetime-grounder.cc b/annotator/datetime/datetime-grounder.cc new file mode 100644 index 0000000..de1b6fa --- /dev/null +++ b/annotator/datetime/datetime-grounder.cc
@@ -0,0 +1,212 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "annotator/datetime/datetime-grounder.h" + +#include <vector> + +#include "annotator/datetime/datetime_generated.h" +#include "annotator/datetime/utils.h" +#include "annotator/types.h" +#include "utils/base/integral_types.h" +#include "utils/base/status.h" +#include "utils/base/status_macros.h" + +using ::libtextclassifier3::grammar::datetime::AbsoluteDateTime; +using ::libtextclassifier3::grammar::datetime::ComponentType; +using ::libtextclassifier3::grammar::datetime::Meridiem; +using ::libtextclassifier3::grammar::datetime::RelativeDateTime; +using ::libtextclassifier3::grammar::datetime::RelativeDatetimeComponent; +using ::libtextclassifier3::grammar::datetime::UngroundedDatetime; +using ::libtextclassifier3::grammar::datetime::RelativeDatetimeComponent_:: + Modifier; + +namespace libtextclassifier3 { + +namespace { + +StatusOr<DatetimeComponent::RelativeQualifier> ToRelativeQualifier( + const Modifier& modifier) { + switch (modifier) { + case Modifier::Modifier_THIS: + return DatetimeComponent::RelativeQualifier::THIS; + case Modifier::Modifier_LAST: + return DatetimeComponent::RelativeQualifier::LAST; + case Modifier::Modifier_NEXT: + return DatetimeComponent::RelativeQualifier::NEXT; + case Modifier::Modifier_NOW: + return DatetimeComponent::RelativeQualifier::NOW; + case Modifier::Modifier_TOMORROW: + return 
DatetimeComponent::RelativeQualifier::TOMORROW; + case Modifier::Modifier_YESTERDAY: + return DatetimeComponent::RelativeQualifier::YESTERDAY; + case Modifier::Modifier_UNSPECIFIED: + return DatetimeComponent::RelativeQualifier::UNSPECIFIED; + default: + return Status(StatusCode::INTERNAL, + "Couldn't parse the Modifier to RelativeQualifier."); + } +} + +StatusOr<DatetimeComponent::ComponentType> ToComponentType( + const grammar::datetime::ComponentType component_type) { + switch (component_type) { + case grammar::datetime::ComponentType_YEAR: + return DatetimeComponent::ComponentType::YEAR; + case grammar::datetime::ComponentType_MONTH: + return DatetimeComponent::ComponentType::MONTH; + case grammar::datetime::ComponentType_WEEK: + return DatetimeComponent::ComponentType::WEEK; + case grammar::datetime::ComponentType_DAY_OF_WEEK: + return DatetimeComponent::ComponentType::DAY_OF_WEEK; + case grammar::datetime::ComponentType_DAY_OF_MONTH: + return DatetimeComponent::ComponentType::DAY_OF_MONTH; + case grammar::datetime::ComponentType_HOUR: + return DatetimeComponent::ComponentType::HOUR; + case grammar::datetime::ComponentType_MINUTE: + return DatetimeComponent::ComponentType::MINUTE; + case grammar::datetime::ComponentType_SECOND: + return DatetimeComponent::ComponentType::SECOND; + case grammar::datetime::ComponentType_MERIDIEM: + return DatetimeComponent::ComponentType::MERIDIEM; + case grammar::datetime::ComponentType_UNSPECIFIED: + return DatetimeComponent::ComponentType::UNSPECIFIED; + default: + return Status(StatusCode::INTERNAL, + "Couldn't parse the DatetimeComponent's ComponentType from " + "grammar's datetime ComponentType."); + } +} + +void FillAbsoluteDateTimeComponents( + const grammar::datetime::AbsoluteDateTime* absolute_datetime, + DatetimeParsedData* datetime_parsed_data) { + if (absolute_datetime->year() >= 0) { + datetime_parsed_data->SetAbsoluteValue( + DatetimeComponent::ComponentType::YEAR, absolute_datetime->year()); + } + if 
(absolute_datetime->month() >= 0) { + datetime_parsed_data->SetAbsoluteValue( + DatetimeComponent::ComponentType::MONTH, absolute_datetime->month()); + } + if (absolute_datetime->day() >= 0) { + datetime_parsed_data->SetAbsoluteValue( + DatetimeComponent::ComponentType::DAY_OF_MONTH, + absolute_datetime->day()); + } + if (absolute_datetime->week_day() >= 0) { + datetime_parsed_data->SetAbsoluteValue( + DatetimeComponent::ComponentType::DAY_OF_WEEK, + absolute_datetime->week_day()); + } + if (absolute_datetime->hour() >= 0) { + datetime_parsed_data->SetAbsoluteValue( + DatetimeComponent::ComponentType::HOUR, absolute_datetime->hour()); + } + if (absolute_datetime->minute() >= 0) { + datetime_parsed_data->SetAbsoluteValue( + DatetimeComponent::ComponentType::MINUTE, absolute_datetime->minute()); + } + if (absolute_datetime->second() >= 0) { + datetime_parsed_data->SetAbsoluteValue( + DatetimeComponent::ComponentType::SECOND, absolute_datetime->second()); + } + if (absolute_datetime->meridiem() != grammar::datetime::Meridiem_UNKNOWN) { + datetime_parsed_data->SetAbsoluteValue( + DatetimeComponent::ComponentType::MERIDIEM, + absolute_datetime->meridiem() == grammar::datetime::Meridiem_AM ? 
0 + : 1); + } + if (absolute_datetime->time_zone()) { + datetime_parsed_data->SetAbsoluteValue( + DatetimeComponent::ComponentType::ZONE_OFFSET, + absolute_datetime->time_zone()->utc_offset_mins()); + } +} + +StatusOr<DatetimeParsedData> FillRelativeDateTimeComponents( + const grammar::datetime::RelativeDateTime* relative_datetime) { + DatetimeParsedData datetime_parsed_data; + for (const RelativeDatetimeComponent* relative_component : + *relative_datetime->relative_datetime_component()) { + TC3_ASSIGN_OR_RETURN(const DatetimeComponent::ComponentType component_type, + ToComponentType(relative_component->component_type())); + datetime_parsed_data.SetRelativeCount(component_type, + relative_component->value()); + TC3_ASSIGN_OR_RETURN( + const DatetimeComponent::RelativeQualifier relative_qualifier, + ToRelativeQualifier(relative_component->modifier())); + datetime_parsed_data.SetRelativeValue(component_type, relative_qualifier); + } + if (relative_datetime->base()) { + FillAbsoluteDateTimeComponents(relative_datetime->base(), + &datetime_parsed_data); + } + return datetime_parsed_data; +} + +} // namespace + +DatetimeGrounder::DatetimeGrounder(const CalendarLib* calendarlib) + : calendarlib_(*calendarlib) {} + +StatusOr<std::vector<DatetimeParseResult>> DatetimeGrounder::Ground( + const int64 reference_time_ms_utc, const std::string& reference_timezone, + const std::string& reference_locale, + const grammar::datetime::UngroundedDatetime* ungrounded_datetime) const { + DatetimeParsedData datetime_parsed_data; + if (ungrounded_datetime->absolute_datetime()) { + FillAbsoluteDateTimeComponents(ungrounded_datetime->absolute_datetime(), + &datetime_parsed_data); + } else if (ungrounded_datetime->relative_datetime()) { + TC3_ASSIGN_OR_RETURN(datetime_parsed_data, + FillRelativeDateTimeComponents( + ungrounded_datetime->relative_datetime())); + } + std::vector<DatetimeParsedData> interpretations; + FillInterpretations(datetime_parsed_data, + 
calendarlib_.GetGranularity(datetime_parsed_data), + &interpretations); + std::vector<DatetimeParseResult> datetime_parse_result; + + for (const DatetimeParsedData& interpretation : interpretations) { + std::vector<DatetimeComponent> date_components; + interpretation.GetDatetimeComponents(&date_components); + DatetimeParseResult result; + // Text classifier only provides ambiguity limited to “AM/PM” which is + // encoded in the pair of DatetimeParseResult; both corresponding to the + // same date, but one corresponding to “AM” and the other one corresponding + // to “PM”. + if (!calendarlib_.InterpretParseData( + interpretation, reference_time_ms_utc, reference_timezone, + reference_locale, /*prefer_future_for_unspecified_date=*/true, + &(result.time_ms_utc), &(result.granularity))) { + return Status( + StatusCode::INTERNAL, + "Couldn't parse the UngroundedDatetime to DatetimeParseResult."); + } + + // Sort the date time units by component type. + std::sort(date_components.begin(), date_components.end(), + [](DatetimeComponent a, DatetimeComponent b) { + return a.component_type > b.component_type; + }); + result.datetime_components.swap(date_components); + datetime_parse_result.push_back(result); + } + return datetime_parse_result; +} + +} // namespace libtextclassifier3 diff --git a/annotator/datetime/datetime-grounder.h b/annotator/datetime/datetime-grounder.h new file mode 100644 index 0000000..4c8502b --- /dev/null +++ b/annotator/datetime/datetime-grounder.h
@@ -0,0 +1,46 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_DATETIME_GROUNDER_H_ +#define LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_DATETIME_GROUNDER_H_ + +#include <vector> + +#include "annotator/datetime/datetime_generated.h" +#include "annotator/types.h" +#include "utils/base/statusor.h" +#include "utils/calendar/calendar.h" + +namespace libtextclassifier3 { + +// Utility class to resolve and complete an ungrounded datetime specification. +class DatetimeGrounder { + public: + explicit DatetimeGrounder(const CalendarLib* calendarlib); + + // Resolves ambiguities and produces concrete datetime results from an + // ungrounded datetime specification. + StatusOr<std::vector<DatetimeParseResult>> Ground( + const int64 reference_time_ms_utc, const std::string& reference_timezone, + const std::string& reference_locale, + const grammar::datetime::UngroundedDatetime* ungrounded_datetime) const; + + private: + const CalendarLib& calendarlib_; +}; + +} // namespace libtextclassifier3 + +#endif // LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_DATETIME_GROUNDER_H_ diff --git a/annotator/datetime/grammar-parser.cc b/annotator/datetime/grammar-parser.cc new file mode 100644 index 0000000..c26f3d6 --- /dev/null +++ b/annotator/datetime/grammar-parser.cc
@@ -0,0 +1,88 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "annotator/datetime/grammar-parser.h" + +#include <set> +#include <unordered_set> + +#include "annotator/datetime/datetime-grounder.h" +#include "utils/grammar/analyzer.h" +#include "utils/grammar/evaluated-derivation.h" + +using ::libtextclassifier3::grammar::EvaluatedDerivation; +using ::libtextclassifier3::grammar::datetime::UngroundedDatetime; + +namespace libtextclassifier3 { + +GrammarDatetimeParser::GrammarDatetimeParser( + const grammar::Analyzer& analyzer, + const DatetimeGrounder& datetime_grounder, + const float target_classification_score, const float priority_score) + : analyzer_(analyzer), + datetime_grounder_(datetime_grounder), + target_classification_score_(target_classification_score), + priority_score_(priority_score) {} + +StatusOr<std::vector<DatetimeParseResultSpan>> GrammarDatetimeParser::Parse( + const std::string& input, const int64 reference_time_ms_utc, + const std::string& reference_timezone, const LocaleList& locale_list, + ModeFlag mode, AnnotationUsecase annotation_usecase, + bool anchor_start_end) const { + return Parse(UTF8ToUnicodeText(input, /*do_copy=*/false), + reference_time_ms_utc, reference_timezone, locale_list, mode, + annotation_usecase, anchor_start_end); +} + +StatusOr<std::vector<DatetimeParseResultSpan>> GrammarDatetimeParser::Parse( + const UnicodeText& input, const int64 reference_time_ms_utc, + 
const std::string& reference_timezone, const LocaleList& locale_list, + ModeFlag mode, AnnotationUsecase annotation_usecase, + bool anchor_start_end) const { + std::vector<DatetimeParseResultSpan> results; + UnsafeArena arena(/*block_size=*/16 << 10); + const std::vector<EvaluatedDerivation> evaluated_derivations = + analyzer_.Parse(input, locale_list.GetLocales(), &arena).ValueOrDie(); + for (const EvaluatedDerivation& evaluated_derivation : + evaluated_derivations) { + if (evaluated_derivation.value) { + if (evaluated_derivation.value->Has<flatbuffers::Table>()) { + const UngroundedDatetime* ungrounded_datetime = + evaluated_derivation.value->Table<UngroundedDatetime>(); + const StatusOr<std::vector<DatetimeParseResult>>& + datetime_parse_results = datetime_grounder_.Ground( + reference_time_ms_utc, reference_timezone, + locale_list.GetReferenceLocale(), ungrounded_datetime); + TC3_ASSIGN_OR_RETURN( + const std::vector<DatetimeParseResult>& parse_datetime, + datetime_parse_results); + DatetimeParseResultSpan datetime_parse_result_span; + datetime_parse_result_span.target_classification_score = + target_classification_score_; + datetime_parse_result_span.priority_score = priority_score_; + datetime_parse_result_span.data.reserve(parse_datetime.size()); + datetime_parse_result_span.data.insert( + datetime_parse_result_span.data.end(), parse_datetime.begin(), + parse_datetime.end()); + datetime_parse_result_span.span = + evaluated_derivation.derivation.parse_tree->codepoint_span; + + results.emplace_back(datetime_parse_result_span); + } + } + } + return results; +} +} // namespace libtextclassifier3 diff --git a/annotator/datetime/grammar-parser.h b/annotator/datetime/grammar-parser.h new file mode 100644 index 0000000..733af16 --- /dev/null +++ b/annotator/datetime/grammar-parser.h
@@ -0,0 +1,67 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_GRAMMAR_PARSER_H_ +#define LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_GRAMMAR_PARSER_H_ + +#include <string> +#include <vector> + +#include "annotator/datetime/datetime-grounder.h" +#include "annotator/datetime/parser.h" +#include "annotator/types.h" +#include "utils/base/statusor.h" +#include "utils/grammar/analyzer.h" +#include "utils/i18n/locale-list.h" +#include "utils/utf8/unicodetext.h" + +namespace libtextclassifier3 { + +// Parses datetime expressions in the input and resolves them to actual absolute +// time. +class GrammarDatetimeParser : public DatetimeParser { + public: + explicit GrammarDatetimeParser(const grammar::Analyzer& analyzer, + const DatetimeGrounder& datetime_grounder, + const float target_classification_score, + const float priority_score); + + // Parses the dates in 'input' and fills result. Makes sure that the results + // do not overlap. + // If 'anchor_start_end' is true the extracted results need to start at the + // beginning of 'input' and end at the end of it. + StatusOr<std::vector<DatetimeParseResultSpan>> Parse( + const std::string& input, int64 reference_time_ms_utc, + const std::string& reference_timezone, const LocaleList& locale_list, + ModeFlag mode, AnnotationUsecase annotation_usecase, + bool anchor_start_end) const override; + + // Same as above but takes UnicodeText. 
+ StatusOr<std::vector<DatetimeParseResultSpan>> Parse( + const UnicodeText& input, int64 reference_time_ms_utc, + const std::string& reference_timezone, const LocaleList& locale_list, + ModeFlag mode, AnnotationUsecase annotation_usecase, + bool anchor_start_end) const override; + + private: + const grammar::Analyzer& analyzer_; + const DatetimeGrounder& datetime_grounder_; + const float target_classification_score_; + const float priority_score_; +}; + +} // namespace libtextclassifier3 + +#endif // LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_GRAMMAR_PARSER_H_ diff --git a/annotator/datetime/parser.h b/annotator/datetime/parser.h index f987c6b..cccd092 100644 --- a/annotator/datetime/parser.h +++ b/annotator/datetime/parser.h
@@ -18,18 +18,13 @@ #include <memory> #include <string> -#include <unordered_map> -#include <unordered_set> #include <vector> -#include "annotator/datetime/extractor.h" -#include "annotator/model_generated.h" #include "annotator/types.h" -#include "utils/base/integral_types.h" -#include "utils/calendar/calendar.h" +#include "utils/base/statusor.h" +#include "utils/i18n/locale-list.h" +#include "utils/i18n/locale.h" #include "utils/utf8/unicodetext.h" -#include "utils/utf8/unilib.h" -#include "utils/zlib/tclib_zlib.h" namespace libtextclassifier3 { @@ -37,87 +32,25 @@ // time. class DatetimeParser { public: - static std::unique_ptr<DatetimeParser> Instance( - const DatetimeModel* model, const UniLib* unilib, - const CalendarLib* calendarlib, ZlibDecompressor* decompressor); + virtual ~DatetimeParser() = default; // Parses the dates in 'input' and fills result. Makes sure that the results // do not overlap. // If 'anchor_start_end' is true the extracted results need to start at the // beginning of 'input' and end at the end of it. - bool Parse(const std::string& input, int64 reference_time_ms_utc, - const std::string& reference_timezone, const std::string& locales, - ModeFlag mode, AnnotationUsecase annotation_usecase, - bool anchor_start_end, - std::vector<DatetimeParseResultSpan>* results) const; + virtual StatusOr<std::vector<DatetimeParseResultSpan>> Parse( + const std::string& input, int64 reference_time_ms_utc, + const std::string& reference_timezone, const LocaleList& locale_list, + ModeFlag mode, AnnotationUsecase annotation_usecase, + bool anchor_start_end) const = 0; // Same as above but takes UnicodeText. 
- bool Parse(const UnicodeText& input, int64 reference_time_ms_utc, - const std::string& reference_timezone, const std::string& locales, - ModeFlag mode, AnnotationUsecase annotation_usecase, - bool anchor_start_end, - std::vector<DatetimeParseResultSpan>* results) const; - - protected: - explicit DatetimeParser(const DatetimeModel* model, const UniLib* unilib, - const CalendarLib* calendarlib, - ZlibDecompressor* decompressor); - - // Returns a list of locale ids for given locale spec string (comma-separated - // locale names). Assigns the first parsed locale to reference_locale. - std::vector<int> ParseAndExpandLocales(const std::string& locales, - std::string* reference_locale) const; - - // Helper function that finds datetime spans, only using the rules associated - // with the given locales. - bool FindSpansUsingLocales( - const std::vector<int>& locale_ids, const UnicodeText& input, - const int64 reference_time_ms_utc, const std::string& reference_timezone, + virtual StatusOr<std::vector<DatetimeParseResultSpan>> Parse( + const UnicodeText& input, int64 reference_time_ms_utc, + const std::string& reference_timezone, const LocaleList& locale_list, ModeFlag mode, AnnotationUsecase annotation_usecase, - bool anchor_start_end, const std::string& reference_locale, - std::unordered_set<int>* executed_rules, - std::vector<DatetimeParseResultSpan>* found_spans) const; - - bool ParseWithRule(const CompiledRule& rule, const UnicodeText& input, - int64 reference_time_ms_utc, - const std::string& reference_timezone, - const std::string& reference_locale, const int locale_id, - bool anchor_start_end, - std::vector<DatetimeParseResultSpan>* result) const; - - // Converts the current match in 'matcher' into DatetimeParseResult. 
- bool ExtractDatetime(const CompiledRule& rule, - const UniLib::RegexMatcher& matcher, - int64 reference_time_ms_utc, - const std::string& reference_timezone, - const std::string& reference_locale, int locale_id, - std::vector<DatetimeParseResult>* results, - CodepointSpan* result_span) const; - - // Parse and extract information from current match in 'matcher'. - bool HandleParseMatch(const CompiledRule& rule, - const UniLib::RegexMatcher& matcher, - int64 reference_time_ms_utc, - const std::string& reference_timezone, - const std::string& reference_locale, int locale_id, - std::vector<DatetimeParseResultSpan>* result) const; - - private: - bool initialized_; - const UniLib& unilib_; - const CalendarLib& calendarlib_; - std::vector<CompiledRule> rules_; - std::unordered_map<int, std::vector<int>> locale_to_rules_; - std::vector<std::unique_ptr<const UniLib::RegexPattern>> extractor_rules_; - std::unordered_map<DatetimeExtractorType, std::unordered_map<int, int>> - type_and_locale_to_extractor_rule_; - std::unordered_map<std::string, int> locale_string_to_id_; - std::vector<int> default_locale_ids_; - bool use_extractors_for_locating_; - bool generate_alternative_interpretations_when_ambiguous_; - bool prefer_future_for_unspecified_date_; + bool anchor_start_end) const = 0; }; - } // namespace libtextclassifier3 #endif // LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_PARSER_H_ diff --git a/annotator/datetime/parser.cc b/annotator/datetime/regex-parser.cc similarity index 69% rename from annotator/datetime/parser.cc rename to annotator/datetime/regex-parser.cc index c93a0a9..7d09e79 100644 --- a/annotator/datetime/parser.cc +++ b/annotator/datetime/regex-parser.cc
@@ -13,33 +13,36 @@ // limitations under the License. // -#include "annotator/datetime/parser.h" +#include "annotator/datetime/regex-parser.h" +#include <iterator> #include <set> #include <unordered_set> #include "annotator/datetime/extractor.h" #include "annotator/datetime/utils.h" +#include "utils/base/statusor.h" #include "utils/calendar/calendar.h" #include "utils/i18n/locale.h" #include "utils/strings/split.h" #include "utils/zlib/zlib_regex.h" namespace libtextclassifier3 { -std::unique_ptr<DatetimeParser> DatetimeParser::Instance( +std::unique_ptr<DatetimeParser> RegexDatetimeParser::Instance( const DatetimeModel* model, const UniLib* unilib, const CalendarLib* calendarlib, ZlibDecompressor* decompressor) { - std::unique_ptr<DatetimeParser> result( - new DatetimeParser(model, unilib, calendarlib, decompressor)); + std::unique_ptr<RegexDatetimeParser> result( + new RegexDatetimeParser(model, unilib, calendarlib, decompressor)); if (!result->initialized_) { result.reset(); } return result; } -DatetimeParser::DatetimeParser(const DatetimeModel* model, const UniLib* unilib, - const CalendarLib* calendarlib, - ZlibDecompressor* decompressor) +RegexDatetimeParser::RegexDatetimeParser(const DatetimeModel* model, + const UniLib* unilib, + const CalendarLib* calendarlib, + ZlibDecompressor* decompressor) : unilib_(*unilib), calendarlib_(*calendarlib) { initialized_ = false; @@ -112,23 +115,24 @@ initialized_ = true; } -bool DatetimeParser::Parse( +StatusOr<std::vector<DatetimeParseResultSpan>> RegexDatetimeParser::Parse( const std::string& input, const int64 reference_time_ms_utc, - const std::string& reference_timezone, const std::string& locales, - ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end, - std::vector<DatetimeParseResultSpan>* results) const { + const std::string& reference_timezone, const LocaleList& locale_list, + ModeFlag mode, AnnotationUsecase annotation_usecase, + bool anchor_start_end) const { return 
Parse(UTF8ToUnicodeText(input, /*do_copy=*/false), - reference_time_ms_utc, reference_timezone, locales, mode, - annotation_usecase, anchor_start_end, results); + reference_time_ms_utc, reference_timezone, locale_list, mode, + annotation_usecase, anchor_start_end); } -bool DatetimeParser::FindSpansUsingLocales( +StatusOr<std::vector<DatetimeParseResultSpan>> +RegexDatetimeParser::FindSpansUsingLocales( const std::vector<int>& locale_ids, const UnicodeText& input, const int64 reference_time_ms_utc, const std::string& reference_timezone, ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end, const std::string& reference_locale, - std::unordered_set<int>* executed_rules, - std::vector<DatetimeParseResultSpan>* found_spans) const { + std::unordered_set<int>* executed_rules) const { + std::vector<DatetimeParseResultSpan> found_spans; for (const int locale_id : locale_ids) { auto rules_it = locale_to_rules_.find(locale_id); if (rules_it == locale_to_rules_.end()) { @@ -151,34 +155,33 @@ } executed_rules->insert(rule_id); - - if (!ParseWithRule(rules_[rule_id], input, reference_time_ms_utc, - reference_timezone, reference_locale, locale_id, - anchor_start_end, found_spans)) { - return false; - } + TC3_ASSIGN_OR_RETURN( + const std::vector<DatetimeParseResultSpan>& found_spans_per_rule, + ParseWithRule(rules_[rule_id], input, reference_time_ms_utc, + reference_timezone, reference_locale, locale_id, + anchor_start_end)); + found_spans.insert(std::end(found_spans), + std::begin(found_spans_per_rule), + std::end(found_spans_per_rule)); } } - return true; + return found_spans; } -bool DatetimeParser::Parse( +StatusOr<std::vector<DatetimeParseResultSpan>> RegexDatetimeParser::Parse( const UnicodeText& input, const int64 reference_time_ms_utc, - const std::string& reference_timezone, const std::string& locales, - ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end, - std::vector<DatetimeParseResultSpan>* results) const { - 
std::vector<DatetimeParseResultSpan> found_spans; + const std::string& reference_timezone, const LocaleList& locale_list, + ModeFlag mode, AnnotationUsecase annotation_usecase, + bool anchor_start_end) const { std::unordered_set<int> executed_rules; - std::string reference_locale; const std::vector<int> requested_locales = - ParseAndExpandLocales(locales, &reference_locale); - if (!FindSpansUsingLocales(requested_locales, input, reference_time_ms_utc, - reference_timezone, mode, annotation_usecase, - anchor_start_end, reference_locale, - &executed_rules, &found_spans)) { - return false; - } - + ParseAndExpandLocales(locale_list.GetLocaleTags()); + TC3_ASSIGN_OR_RETURN( + const std::vector<DatetimeParseResultSpan>& found_spans, + FindSpansUsingLocales(requested_locales, input, reference_time_ms_utc, + reference_timezone, mode, annotation_usecase, + anchor_start_end, locale_list.GetReferenceLocale(), + &executed_rules)); std::vector<std::pair<DatetimeParseResultSpan, int>> indexed_found_spans; indexed_found_spans.reserve(found_spans.size()); for (int i = 0; i < found_spans.size(); i++) { @@ -199,39 +202,46 @@ } }); - found_spans.clear(); + std::vector<DatetimeParseResultSpan> results; + std::vector<DatetimeParseResultSpan> resolved_found_spans; + resolved_found_spans.reserve(indexed_found_spans.size()); for (auto& span_index_pair : indexed_found_spans) { - found_spans.push_back(span_index_pair.first); + resolved_found_spans.push_back(span_index_pair.first); } std::set<int, std::function<bool(int, int)>> chosen_indices_set( - [&found_spans](int a, int b) { - return found_spans[a].span.first < found_spans[b].span.first; + [&resolved_found_spans](int a, int b) { + return resolved_found_spans[a].span.first < + resolved_found_spans[b].span.first; }); - for (int i = 0; i < found_spans.size(); ++i) { - if (!DoesCandidateConflict(i, found_spans, chosen_indices_set)) { + for (int i = 0; i < resolved_found_spans.size(); ++i) { + if (!DoesCandidateConflict(i, 
resolved_found_spans, chosen_indices_set)) { chosen_indices_set.insert(i); - results->push_back(found_spans[i]); + results.push_back(resolved_found_spans[i]); } } - - return true; + return results; } -bool DatetimeParser::HandleParseMatch( - const CompiledRule& rule, const UniLib::RegexMatcher& matcher, - int64 reference_time_ms_utc, const std::string& reference_timezone, - const std::string& reference_locale, int locale_id, - std::vector<DatetimeParseResultSpan>* result) const { +StatusOr<std::vector<DatetimeParseResultSpan>> +RegexDatetimeParser::HandleParseMatch(const CompiledRule& rule, + const UniLib::RegexMatcher& matcher, + int64 reference_time_ms_utc, + const std::string& reference_timezone, + const std::string& reference_locale, + int locale_id) const { + std::vector<DatetimeParseResultSpan> results; int status = UniLib::RegexMatcher::kNoError; const int start = matcher.Start(&status); if (status != UniLib::RegexMatcher::kNoError) { - return false; + return Status(StatusCode::INTERNAL, + "Failed to get the start offset of the last match."); } const int end = matcher.End(&status); if (status != UniLib::RegexMatcher::kNoError) { - return false; + return Status(StatusCode::INTERNAL, + "Failed to get the end offset of the last match."); } DatetimeParseResultSpan parse_result; @@ -239,7 +249,7 @@ if (!ExtractDatetime(rule, matcher, reference_time_ms_utc, reference_timezone, reference_locale, locale_id, &alternatives, &parse_result.span)) { - return false; + return Status(StatusCode::INTERNAL, "Failed to extract Datetime."); } if (!use_extractors_for_locating_) { @@ -256,49 +266,44 @@ parse_result.data.push_back(alternative); } } - result->push_back(parse_result); - return true; + results.push_back(parse_result); + return results; } -bool DatetimeParser::ParseWithRule( - const CompiledRule& rule, const UnicodeText& input, - const int64 reference_time_ms_utc, const std::string& reference_timezone, - const std::string& reference_locale, const int locale_id, - 
bool anchor_start_end, std::vector<DatetimeParseResultSpan>* result) const { +StatusOr<std::vector<DatetimeParseResultSpan>> +RegexDatetimeParser::ParseWithRule(const CompiledRule& rule, + const UnicodeText& input, + const int64 reference_time_ms_utc, + const std::string& reference_timezone, + const std::string& reference_locale, + const int locale_id, + bool anchor_start_end) const { + std::vector<DatetimeParseResultSpan> results; std::unique_ptr<UniLib::RegexMatcher> matcher = rule.compiled_regex->Matcher(input); int status = UniLib::RegexMatcher::kNoError; if (anchor_start_end) { if (matcher->Matches(&status) && status == UniLib::RegexMatcher::kNoError) { - if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc, - reference_timezone, reference_locale, locale_id, - result)) { - return false; - } + return HandleParseMatch(rule, *matcher, reference_time_ms_utc, + reference_timezone, reference_locale, locale_id); } } else { while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) { - if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc, - reference_timezone, reference_locale, locale_id, - result)) { - return false; - } + TC3_ASSIGN_OR_RETURN( + const std::vector<DatetimeParseResultSpan>& pattern_occurrence, + HandleParseMatch(rule, *matcher, reference_time_ms_utc, + reference_timezone, reference_locale, locale_id)); + results.insert(std::end(results), std::begin(pattern_occurrence), + std::end(pattern_occurrence)); } } - return true; + return results; } -std::vector<int> DatetimeParser::ParseAndExpandLocales( - const std::string& locales, std::string* reference_locale) const { - std::vector<StringPiece> split_locales = strings::Split(locales, ','); - if (!split_locales.empty()) { - *reference_locale = split_locales[0].ToString(); - } else { - *reference_locale = ""; - } - +std::vector<int> RegexDatetimeParser::ParseAndExpandLocales( + const std::vector<StringPiece>& locales) const { std::vector<int> result; - for (const StringPiece& 
locale_str : split_locales) { + for (const StringPiece& locale_str : locales) { auto locale_it = locale_string_to_id_.find(locale_str.ToString()); if (locale_it != locale_string_to_id_.end()) { result.push_back(locale_it->second); @@ -347,14 +352,12 @@ return result; } -bool DatetimeParser::ExtractDatetime(const CompiledRule& rule, - const UniLib::RegexMatcher& matcher, - const int64 reference_time_ms_utc, - const std::string& reference_timezone, - const std::string& reference_locale, - int locale_id, - std::vector<DatetimeParseResult>* results, - CodepointSpan* result_span) const { +bool RegexDatetimeParser::ExtractDatetime( + const CompiledRule& rule, const UniLib::RegexMatcher& matcher, + const int64 reference_time_ms_utc, const std::string& reference_timezone, + const std::string& reference_locale, int locale_id, + std::vector<DatetimeParseResult>* results, + CodepointSpan* result_span) const { DatetimeParsedData parse; DatetimeExtractor extractor(rule, matcher, locale_id, &unilib_, extractor_rules_, diff --git a/annotator/datetime/regex-parser.h b/annotator/datetime/regex-parser.h new file mode 100644 index 0000000..7db04dc --- /dev/null +++ b/annotator/datetime/regex-parser.h
@@ -0,0 +1,122 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_REGEX_PARSER_H_ +#define LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_REGEX_PARSER_H_ + +#include <memory> +#include <string> +#include <unordered_map> +#include <unordered_set> +#include <vector> + +#include "annotator/datetime/extractor.h" +#include "annotator/datetime/parser.h" +#include "annotator/model_generated.h" +#include "annotator/types.h" +#include "utils/base/integral_types.h" +#include "utils/base/statusor.h" +#include "utils/calendar/calendar.h" +#include "utils/strings/stringpiece.h" +#include "utils/utf8/unicodetext.h" +#include "utils/utf8/unilib.h" +#include "utils/zlib/tclib_zlib.h" + +namespace libtextclassifier3 { + +// Parses datetime expressions in the input and resolves them to actual absolute +// time. +class RegexDatetimeParser : public DatetimeParser { + public: + static std::unique_ptr<DatetimeParser> Instance( + const DatetimeModel* model, const UniLib* unilib, + const CalendarLib* calendarlib, ZlibDecompressor* decompressor); + + // Parses the dates in 'input' and fills result. Makes sure that the results + // do not overlap. + // If 'anchor_start_end' is true the extracted results need to start at the + // beginning of 'input' and end at the end of it. 
+ StatusOr<std::vector<DatetimeParseResultSpan>> Parse( + const std::string& input, int64 reference_time_ms_utc, + const std::string& reference_timezone, const LocaleList& locale_list, + ModeFlag mode, AnnotationUsecase annotation_usecase, + bool anchor_start_end) const override; + + // Same as above but takes UnicodeText. + StatusOr<std::vector<DatetimeParseResultSpan>> Parse( + const UnicodeText& input, int64 reference_time_ms_utc, + const std::string& reference_timezone, const LocaleList& locale_list, + ModeFlag mode, AnnotationUsecase annotation_usecase, + bool anchor_start_end) const override; + + protected: + explicit RegexDatetimeParser(const DatetimeModel* model, const UniLib* unilib, + const CalendarLib* calendarlib, + ZlibDecompressor* decompressor); + + // Returns a list of locale ids for given locale spec string (collection of + // locale names). + std::vector<int> ParseAndExpandLocales( + const std::vector<StringPiece>& locales) const; + + // Helper function that finds datetime spans, only using the rules associated + // with the given locales. + StatusOr<std::vector<DatetimeParseResultSpan>> FindSpansUsingLocales( + const std::vector<int>& locale_ids, const UnicodeText& input, + const int64 reference_time_ms_utc, const std::string& reference_timezone, + ModeFlag mode, AnnotationUsecase annotation_usecase, + bool anchor_start_end, const std::string& reference_locale, + std::unordered_set<int>* executed_rules) const; + + StatusOr<std::vector<DatetimeParseResultSpan>> ParseWithRule( + const CompiledRule& rule, const UnicodeText& input, + int64 reference_time_ms_utc, const std::string& reference_timezone, + const std::string& reference_locale, const int locale_id, + bool anchor_start_end) const; + + // Converts the current match in 'matcher' into DatetimeParseResult. 
+ bool ExtractDatetime(const CompiledRule& rule, + const UniLib::RegexMatcher& matcher, + int64 reference_time_ms_utc, + const std::string& reference_timezone, + const std::string& reference_locale, int locale_id, + std::vector<DatetimeParseResult>* results, + CodepointSpan* result_span) const; + + // Parse and extract information from current match in 'matcher'. + StatusOr<std::vector<DatetimeParseResultSpan>> HandleParseMatch( + const CompiledRule& rule, const UniLib::RegexMatcher& matcher, + int64 reference_time_ms_utc, const std::string& reference_timezone, + const std::string& reference_locale, int locale_id) const; + + private: + bool initialized_; + const UniLib& unilib_; + const CalendarLib& calendarlib_; + std::vector<CompiledRule> rules_; + std::unordered_map<int, std::vector<int>> locale_to_rules_; + std::vector<std::unique_ptr<const UniLib::RegexPattern>> extractor_rules_; + std::unordered_map<DatetimeExtractorType, std::unordered_map<int, int>> + type_and_locale_to_extractor_rule_; + std::unordered_map<std::string, int> locale_string_to_id_; + std::vector<int> default_locale_ids_; + bool use_extractors_for_locating_; + bool generate_alternative_interpretations_when_ambiguous_; + bool prefer_future_for_unspecified_date_; +}; + +} // namespace libtextclassifier3 + +#endif // LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_REGEX_PARSER_H_ diff --git a/annotator/grammar/dates/annotations/annotation-options.h b/annotator/grammar/dates/annotations/annotation-options.h deleted file mode 100755 index 6c18ffd..0000000 --- a/annotator/grammar/dates/annotations/annotation-options.h +++ /dev/null
@@ -1,95 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_ANNOTATIONS_ANNOTATION_OPTIONS_H_ -#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_ANNOTATIONS_ANNOTATION_OPTIONS_H_ - -#include <string> -#include <vector> - -#include "utils/base/integral_types.h" - -namespace libtextclassifier3 { - -// Options for date/datetime/date range annotations. -struct DateAnnotationOptions { - // If enabled, extract special day offset like today, yesterday, etc. - bool enable_special_day_offset; - - // If true, merge the adjacent day of week, time and date. e.g. - // "20/2/2016 at 8pm" is extracted as a single instance instead of two - // instance: "20/2/2016" and "8pm". - bool merge_adjacent_components; - - // List the extra id of requested dates. - std::vector<std::string> extra_requested_dates; - - // If true, try to include preposition to the extracted annotation. e.g. - // "at 6pm". if it's false, only 6pm is included. offline-actions has special - // requirements to include preposition. - bool include_preposition; - - // The base timestamp (milliseconds) which used to convert relative time to - // absolute time. - // e.g.: - // base timestamp is 2016/4/25, then tomorrow will be converted to - // 2016/4/26. 
- // base timestamp is 2016/4/25 10:30:20am, then 1 days, 2 hours, 10 minutes - // and 5 seconds ago will be converted to 2016/4/24 08:20:15am - int64 base_timestamp_millis; - - // If enabled, extract range in date annotator. - // input: Monday, 5-6pm - // If the flag is true, The extracted annotation only contains 1 range - // instance which is from Monday 5pm to 6pm. - // If the flag is false, The extracted annotation contains two date - // instance: "Monday" and "6pm". - bool enable_date_range; - - // Timezone in which the input text was written - std::string reference_timezone; - // Localization params. - // The format of the locale lists should be "<lang_code-<county_code>" - // comma-separated list of two-character language/country pairs. - std::string locales; - - // If enabled, the annotation/rule_match priority score is used to set the and - // priority score of the annotation. - // In case of false the annotation priority score are set from - // GrammarDatetimeModel's priority_score - bool use_rule_priority_score; - - // If enabled, annotator will try to resolve the ambiguity by generating - // possible alternative interpretations of the input text - // e.g. '9:45' will be resolved to '9:45 AM' and '9:45 PM'. - bool generate_alternative_interpretations_when_ambiguous; - - // List the ignored span in the date string e.g. 12 March @12PM, here '@' - // can be ignored tokens. 
- std::vector<std::string> ignored_spans; - - // Default Constructor - DateAnnotationOptions() - : enable_special_day_offset(true), - merge_adjacent_components(true), - include_preposition(false), - base_timestamp_millis(0), - enable_date_range(false), - use_rule_priority_score(false), - generate_alternative_interpretations_when_ambiguous(false) {} -}; - -} // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_ANNOTATIONS_ANNOTATION_OPTIONS_H_ diff --git a/annotator/grammar/dates/annotations/annotation-util.cc b/annotator/grammar/dates/annotations/annotation-util.cc deleted file mode 100644 index 9c45223..0000000 --- a/annotator/grammar/dates/annotations/annotation-util.cc +++ /dev/null
@@ -1,100 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -#include "annotator/grammar/dates/annotations/annotation-util.h" - -#include <algorithm> - -#include "utils/base/logging.h" - -namespace libtextclassifier3 { - -int GetPropertyIndex(StringPiece name, const AnnotationData& annotation_data) { - for (int i = 0; i < annotation_data.properties.size(); ++i) { - if (annotation_data.properties[i].name == name.ToString()) { - return i; - } - } - return -1; -} - -int GetPropertyIndex(StringPiece name, const Annotation& annotation) { - return GetPropertyIndex(name, annotation.data); -} - -int GetIntProperty(StringPiece name, const Annotation& annotation) { - return GetIntProperty(name, annotation.data); -} - -int GetIntProperty(StringPiece name, const AnnotationData& annotation_data) { - const int index = GetPropertyIndex(name, annotation_data); - if (index < 0) { - TC3_DCHECK_GE(index, 0) - << "No property with name " << name.ToString() << "."; - return 0; - } - - if (annotation_data.properties.at(index).int_values.size() != 1) { - TC3_DCHECK_EQ(annotation_data.properties[index].int_values.size(), 1); - return 0; - } - - return annotation_data.properties.at(index).int_values.at(0); -} - -int AddIntProperty(StringPiece name, int value, Annotation* annotation) { - return AddRepeatedIntProperty(name, &value, 1, annotation); -} - -int AddIntProperty(StringPiece name, int value, - AnnotationData* annotation_data) { - 
return AddRepeatedIntProperty(name, &value, 1, annotation_data); -} - -int AddRepeatedIntProperty(StringPiece name, const int* start, int size, - Annotation* annotation) { - return AddRepeatedIntProperty(name, start, size, &annotation->data); -} - -int AddRepeatedIntProperty(StringPiece name, const int* start, int size, - AnnotationData* annotation_data) { - Property property; - property.name = name.ToString(); - auto first = start; - auto last = start + size; - while (first != last) { - property.int_values.push_back(*first); - first++; - } - annotation_data->properties.push_back(property); - return annotation_data->properties.size() - 1; -} - -int AddAnnotationDataProperty(const std::string& key, - const AnnotationData& value, - AnnotationData* annotation_data) { - Property property; - property.name = key; - property.annotation_data_values.push_back(value); - annotation_data->properties.push_back(property); - return annotation_data->properties.size() - 1; -} - -int AddAnnotationDataProperty(const std::string& key, - const AnnotationData& value, - Annotation* annotation) { - return AddAnnotationDataProperty(key, value, &annotation->data); -} -} // namespace libtextclassifier3 diff --git a/annotator/grammar/dates/annotations/annotation-util.h b/annotator/grammar/dates/annotations/annotation-util.h deleted file mode 100644 index bf60323..0000000 --- a/annotator/grammar/dates/annotations/annotation-util.h +++ /dev/null
@@ -1,74 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_ANNOTATIONS_ANNOTATION_UTIL_H_ -#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_ANNOTATIONS_ANNOTATION_UTIL_H_ - -#include "annotator/grammar/dates/annotations/annotation.h" -#include "utils/strings/stringpiece.h" - -namespace libtextclassifier3 { - -// Return the index of property in annotation.data().properties(). -// Return -1 if the property does not exist. -int GetPropertyIndex(StringPiece name, const Annotation& annotation); - -// Return the index of property in thing.properties(). -// Return -1 if the property does not exist. -int GetPropertyIndex(StringPiece name, const AnnotationData& annotation_data); - -// Return the single int value for property 'name' of the annotation. -// Returns 0 if the property does not exist or does not contain a single int -// value. -int GetIntProperty(StringPiece name, const Annotation& annotation); - -// Return the single float value for property 'name' of the annotation. -// Returns 0 if the property does not exist or does not contain a single int -// value. -int GetIntProperty(StringPiece name, const AnnotationData& annotation_data); - -// Add a new property with a single int value to an Annotation instance. -// Return the index of the property. 
-int AddIntProperty(StringPiece name, const int value, Annotation* annotation); - -// Add a new property with a single int value to a Thing instance. -// Return the index of the property. -int AddIntProperty(StringPiece name, const int value, - AnnotationData* annotation_data); - -// Add a new property with repeated int values to an Annotation instance. -// Return the index of the property. -int AddRepeatedIntProperty(StringPiece name, const int* start, int size, - Annotation* annotation); - -// Add a new property with repeated int values to a Thing instance. -// Return the index of the property. -int AddRepeatedIntProperty(StringPiece name, const int* start, int size, - AnnotationData* annotation_data); - -// Add a new property with Thing value. -// Return the index of the property. -int AddAnnotationDataProperty(const std::string& key, - const AnnotationData& value, - Annotation* annotation); - -// Add a new property with Thing value. -// Return the index of the property. -int AddAnnotationDataProperty(const std::string& key, - const AnnotationData& value, - AnnotationData* annotation_data); - -} // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_ANNOTATIONS_ANNOTATION_UTIL_H_ diff --git a/annotator/grammar/dates/annotations/annotation.h b/annotator/grammar/dates/annotations/annotation.h deleted file mode 100644 index 1cbf598..0000000 --- a/annotator/grammar/dates/annotations/annotation.h +++ /dev/null
@@ -1,70 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_ANNOTATIONS_ANNOTATION_H_ -#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_ANNOTATIONS_ANNOTATION_H_ - -#include <string> -#include <vector> - -#include "utils/base/integral_types.h" - -namespace libtextclassifier3 { - -struct AnnotationData; - -// Define enum for each annotation. -enum GrammarAnnotationType { - // Date&time like "May 1", "12:20pm", etc. - DATETIME = 0, - // Datetime range like "2pm - 3pm". - DATETIME_RANGE = 1, -}; - -struct Property { - // TODO(hassan): Replace the name with enum e.g. PropertyType. - std::string name; - // At most one of these will have any values. - std::vector<bool> bool_values; - std::vector<int64> int_values; - std::vector<double> double_values; - std::vector<std::string> string_values; - std::vector<AnnotationData> annotation_data_values; -}; - -struct AnnotationData { - // TODO(hassan): Replace it type with GrammarAnnotationType - std::string type; - std::vector<Property> properties; -}; - -// Represents an annotation instance. -// lets call it either AnnotationDetails -struct Annotation { - // Codepoint offsets into the original text specifying the substring of the - // text that was annotated. - int32 begin; - int32 end; - - // Annotation priority score which can be used to resolve conflict between - // annotators. 
- float annotator_priority_score; - - // Represents the details of the annotation instance, including the type of - // the annotation instance and its properties. - AnnotationData data; -}; -} // namespace libtextclassifier3 -#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_ANNOTATIONS_ANNOTATION_H_ diff --git a/annotator/grammar/dates/cfg-datetime-annotator.cc b/annotator/grammar/dates/cfg-datetime-annotator.cc deleted file mode 100644 index 887f554..0000000 --- a/annotator/grammar/dates/cfg-datetime-annotator.cc +++ /dev/null
@@ -1,138 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -#include "annotator/grammar/dates/cfg-datetime-annotator.h" - -#include "annotator/datetime/utils.h" -#include "annotator/grammar/dates/annotations/annotation-options.h" -#include "annotator/grammar/utils.h" -#include "utils/strings/split.h" -#include "utils/tokenizer.h" -#include "utils/utf8/unicodetext.h" - -namespace libtextclassifier3::dates { -namespace { - -static std::string GetReferenceLocale(const std::string& locales) { - std::vector<StringPiece> split_locales = strings::Split(locales, ','); - if (!split_locales.empty()) { - return split_locales[0].ToString(); - } - return ""; -} - -static void InterpretParseData(const DatetimeParsedData& datetime_parsed_data, - const DateAnnotationOptions& options, - const CalendarLib& calendarlib, - int64* interpreted_time_ms_utc, - DatetimeGranularity* granularity) { - DatetimeGranularity local_granularity = - calendarlib.GetGranularity(datetime_parsed_data); - if (!calendarlib.InterpretParseData( - datetime_parsed_data, options.base_timestamp_millis, - options.reference_timezone, GetReferenceLocale(options.locales), - /*prefer_future_for_unspecified_date=*/true, interpreted_time_ms_utc, - granularity)) { - TC3_LOG(WARNING) << "Failed to extract time in millis and Granularity."; - // Fallingback to DatetimeParsedData's finest granularity - *granularity = local_granularity; - } -} - -} // namespace - 
-CfgDatetimeAnnotator::CfgDatetimeAnnotator( - const UniLib* unilib, const GrammarTokenizerOptions* tokenizer_options, - const CalendarLib* calendar_lib, const DatetimeRules* datetime_rules, - const float annotator_target_classification_score, - const float annotator_priority_score) - : calendar_lib_(*calendar_lib), - tokenizer_(BuildTokenizer(unilib, tokenizer_options)), - parser_(unilib, datetime_rules), - annotator_target_classification_score_( - annotator_target_classification_score), - annotator_priority_score_(annotator_priority_score) {} - -void CfgDatetimeAnnotator::Parse( - const std::string& input, const DateAnnotationOptions& annotation_options, - const std::vector<Locale>& locales, - std::vector<DatetimeParseResultSpan>* results) const { - Parse(UTF8ToUnicodeText(input, /*do_copy=*/false), annotation_options, - locales, results); -} - -void CfgDatetimeAnnotator::ProcessDatetimeParseResult( - const DateAnnotationOptions& annotation_options, - const DatetimeParseResult& datetime_parse_result, - std::vector<DatetimeParseResult>* results) const { - DatetimeParsedData datetime_parsed_data; - datetime_parsed_data.AddDatetimeComponents( - datetime_parse_result.datetime_components); - - std::vector<DatetimeParsedData> interpretations; - if (annotation_options.generate_alternative_interpretations_when_ambiguous) { - FillInterpretations(datetime_parsed_data, - calendar_lib_.GetGranularity(datetime_parsed_data), - &interpretations); - } else { - interpretations.emplace_back(datetime_parsed_data); - } - for (const DatetimeParsedData& interpretation : interpretations) { - results->emplace_back(); - interpretation.GetDatetimeComponents(&results->back().datetime_components); - InterpretParseData(interpretation, annotation_options, calendar_lib_, - &(results->back().time_ms_utc), - &(results->back().granularity)); - std::sort(results->back().datetime_components.begin(), - results->back().datetime_components.end(), - [](const DatetimeComponent& a, const 
DatetimeComponent& b) { - return a.component_type > b.component_type; - }); - } -} - -void CfgDatetimeAnnotator::Parse( - const UnicodeText& input, const DateAnnotationOptions& annotation_options, - const std::vector<Locale>& locales, - std::vector<DatetimeParseResultSpan>* results) const { - std::vector<DatetimeParseResultSpan> grammar_datetime_parse_result_spans = - parser_.Parse(input.data(), tokenizer_.Tokenize(input), locales, - annotation_options); - - for (const DatetimeParseResultSpan& grammar_datetime_parse_result_span : - grammar_datetime_parse_result_spans) { - DatetimeParseResultSpan datetime_parse_result_span; - datetime_parse_result_span.span.first = - grammar_datetime_parse_result_span.span.first; - datetime_parse_result_span.span.second = - grammar_datetime_parse_result_span.span.second; - datetime_parse_result_span.priority_score = annotator_priority_score_; - if (annotation_options.use_rule_priority_score) { - datetime_parse_result_span.priority_score = - grammar_datetime_parse_result_span.priority_score; - } - datetime_parse_result_span.target_classification_score = - annotator_target_classification_score_; - for (const DatetimeParseResult& grammar_datetime_parse_result : - grammar_datetime_parse_result_span.data) { - ProcessDatetimeParseResult(annotation_options, - grammar_datetime_parse_result, - &datetime_parse_result_span.data); - } - results->emplace_back(datetime_parse_result_span); - } -} - -} // namespace libtextclassifier3::dates diff --git a/annotator/grammar/dates/cfg-datetime-annotator.h b/annotator/grammar/dates/cfg-datetime-annotator.h deleted file mode 100644 index 660cf76..0000000 --- a/annotator/grammar/dates/cfg-datetime-annotator.h +++ /dev/null
@@ -1,75 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -#pragma GCC diagnostic ignored "-Wc++17-extensions" - -#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_CFG_DATETIME_ANNOTATOR_H_ -#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_CFG_DATETIME_ANNOTATOR_H_ - -#include "annotator/grammar/dates/annotations/annotation.h" -#include "annotator/grammar/dates/dates_generated.h" -#include "annotator/grammar/dates/parser.h" -#include "annotator/grammar/dates/utils/annotation-keys.h" -#include "annotator/model_generated.h" -#include "utils/calendar/calendar.h" -#include "utils/i18n/locale.h" -#include "utils/tokenizer.h" -#include "utils/utf8/unilib.h" - -namespace libtextclassifier3::dates { - -// Helper class to convert the parsed datetime expression from AnnotationList -// (List of annotation generated from Grammar rules) to DatetimeParseResultSpan. -class CfgDatetimeAnnotator { - public: - explicit CfgDatetimeAnnotator( - const UniLib* unilib, const GrammarTokenizerOptions* tokenizer_options, - const CalendarLib* calendar_lib, const DatetimeRules* datetime_rules, - const float annotator_target_classification_score, - const float annotator_priority_score); - - // CfgDatetimeAnnotator is neither copyable nor movable. - CfgDatetimeAnnotator(const CfgDatetimeAnnotator&) = delete; - CfgDatetimeAnnotator& operator=(const CfgDatetimeAnnotator&) = delete; - - // Parses the dates in 'input' and fills result. 
Makes sure that the results - // do not overlap. - // Method will return false if input does not contain any datetime span. - void Parse(const std::string& input, - const DateAnnotationOptions& annotation_options, - const std::vector<Locale>& locales, - std::vector<DatetimeParseResultSpan>* results) const; - - // UnicodeText version of parse. - void Parse(const UnicodeText& input, - const DateAnnotationOptions& annotation_options, - const std::vector<Locale>& locales, - std::vector<DatetimeParseResultSpan>* results) const; - - private: - void ProcessDatetimeParseResult( - const DateAnnotationOptions& annotation_options, - const DatetimeParseResult& datetime_parse_result, - std::vector<DatetimeParseResult>* results) const; - - const CalendarLib& calendar_lib_; - const Tokenizer tokenizer_; - DateParser parser_; - const float annotator_target_classification_score_; - const float annotator_priority_score_; -}; - -} // namespace libtextclassifier3::dates -#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_CFG_DATETIME_ANNOTATOR_H_ diff --git a/annotator/grammar/dates/dates.fbs b/annotator/grammar/dates/dates.fbs deleted file mode 100755 index 07e1964..0000000 --- a/annotator/grammar/dates/dates.fbs +++ /dev/null
@@ -1,350 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -include "annotator/grammar/dates/timezone-code.fbs"; -include "utils/grammar/rules.fbs"; - -// Type identifiers of all non-trivial matches. -namespace libtextclassifier3.dates; -enum MatchType : int { - UNKNOWN = 0, - - // Match of a date extraction rule. - DATETIME_RULE = 1, - - // Match of a date range extraction rule. - DATETIME_RANGE_RULE = 2, - - // Match defined by an ExtractionRule (e.g., a single time-result that is - // matched by a time-rule, which is ready to be output individually, with - // this kind of match, we can retrieve it in range rules). - DATETIME = 3, - - // Match defined by TermValue. - TERM_VALUE = 4, - - // Matches defined by Nonterminal. 
- NONTERMINAL = 5, - - DIGITS = 6, - YEAR = 7, - MONTH = 8, - DAY = 9, - HOUR = 10, - MINUTE = 11, - SECOND = 12, - FRACTION_SECOND = 13, - DAY_OF_WEEK = 14, - TIME_VALUE = 15, - TIME_SPAN = 16, - TIME_ZONE_NAME = 17, - TIME_ZONE_OFFSET = 18, - TIME_PERIOD = 19, - RELATIVE_DATE = 20, - COMBINED_DIGITS = 21, -} - -namespace libtextclassifier3.dates; -enum BCAD : int { - BCAD_NONE = -1, - BC = 0, - AD = 1, -} - -namespace libtextclassifier3.dates; -enum DayOfWeek : int { - DOW_NONE = -1, - SUNDAY = 1, - MONDAY = 2, - TUESDAY = 3, - WEDNESDAY = 4, - THURSDAY = 5, - FRIDAY = 6, - SATURDAY = 7, -} - -namespace libtextclassifier3.dates; -enum TimespanCode : int { - TIMESPAN_CODE_NONE = -1, - AM = 0, - PM = 1, - NOON = 2, - MIDNIGHT = 3, - - // English "tonight". - TONIGHT = 11, -} - -// The datetime grammar rules. -namespace libtextclassifier3.dates; -table DatetimeRules { - // The context free grammar rules. - rules:grammar.RulesSet; - - // Values associated with grammar rule matches. - extraction_rule:[ExtractionRuleParameter]; - - term_value:[TermValue]; - nonterminal_value:[NonterminalValue]; -} - -namespace libtextclassifier3.dates; -table TermValue { - value:int; - - // A time segment e.g. 10AM - 12AM - time_span_spec:TimeSpanSpec; - - // Time zone information representation - time_zone_name_spec:TimeZoneNameSpec; -} - -// Define nonterms from terms or other nonterms. -namespace libtextclassifier3.dates; -table NonterminalValue { - // Mapping value. - value:TermValue; - - // Parameter describing formatting choices for nonterminal messages - nonterminal_parameter:NonterminalParameter; - - // Parameter interpreting past/future dates (e.g. "last year") - relative_parameter:RelativeParameter; - - // Format info for nonterminals representing times. - time_value_parameter:TimeValueParameter; - - // Parameter describing the format of time-zone info - e.g. 
"UTC-8" - time_zone_offset_parameter:TimeZoneOffsetParameter; -} - -namespace libtextclassifier3.dates.RelativeParameter_; -enum RelativeType : int { - NONE = 0, - YEAR = 1, - MONTH = 2, - DAY = 3, - WEEK = 4, - HOUR = 5, - MINUTE = 6, - SECOND = 7, -} - -namespace libtextclassifier3.dates.RelativeParameter_; -enum Period : int { - PERIOD_UNKNOWN = 0, - PERIOD_PAST = 1, - PERIOD_FUTURE = 2, -} - -// Relative interpretation. -// Indicates which day the day of week could be, for example "next Friday" -// could means the Friday which is the closest Friday or the Friday in the -// next week. -namespace libtextclassifier3.dates.RelativeParameter_; -enum Interpretation : int { - UNKNOWN = 0, - - // The closest X in the past. - NEAREST_LAST = 1, - - // The X before the closest X in the past. - SECOND_LAST = 2, - - // The closest X in the future. - NEAREST_NEXT = 3, - - // The X after the closest X in the future. - SECOND_NEXT = 4, - - // X in the previous one. - PREVIOUS = 5, - - // X in the coming one. - COMING = 6, - - // X in current one, it can be both past and future. - CURRENT = 7, - - // Some X. - SOME = 8, - - // The closest X, it can be both past and future. - NEAREST = 9, -} - -namespace libtextclassifier3.dates; -table RelativeParameter { - type:RelativeParameter_.RelativeType = NONE; - period:RelativeParameter_.Period = PERIOD_UNKNOWN; - day_of_week_interpretation:[RelativeParameter_.Interpretation]; -} - -namespace libtextclassifier3.dates.NonterminalParameter_; -enum Flag : int { - IS_SPELLED = 1, -} - -namespace libtextclassifier3.dates; -table NonterminalParameter { - // Bit-wise OR Flag. - flag:uint = 0; - - combined_digits_format:string; -} - -namespace libtextclassifier3.dates.TimeValueParameter_; -enum TimeValueValidation : int { - // Allow extra spaces between sub-components in time-value. - ALLOW_EXTRA_SPACE = 1, - // 1 << 0 - - // Disallow colon- or dot-context with digits for time-value. 
- DISALLOW_COLON_DOT_CONTEXT = 2, - // 1 << 1 -} - -namespace libtextclassifier3.dates; -table TimeValueParameter { - validation:uint = 0; - // Bitwise-OR - - flag:uint = 0; - // Bitwise-OR -} - -namespace libtextclassifier3.dates.TimeZoneOffsetParameter_; -enum Format : int { - // Offset is in an uncategorized format. - FORMAT_UNKNOWN = 0, - - // Offset contains 1-digit hour only, e.g. "UTC-8". - FORMAT_H = 1, - - // Offset contains 2-digit hour only, e.g. "UTC-08". - FORMAT_HH = 2, - - // Offset contains 1-digit hour and minute, e.g. "UTC-8:00". - FORMAT_H_MM = 3, - - // Offset contains 2-digit hour and minute, e.g. "UTC-08:00". - FORMAT_HH_MM = 4, - - // Offset contains 3-digit hour-and-minute, e.g. "UTC-800". - FORMAT_HMM = 5, - - // Offset contains 4-digit hour-and-minute, e.g. "UTC-0800". - FORMAT_HHMM = 6, -} - -namespace libtextclassifier3.dates; -table TimeZoneOffsetParameter { - format:TimeZoneOffsetParameter_.Format = FORMAT_UNKNOWN; -} - -namespace libtextclassifier3.dates.ExtractionRuleParameter_; -enum ExtractionValidation : int { - // Boundary checking for final match. - LEFT_BOUND = 1, - - RIGHT_BOUND = 2, - SPELLED_YEAR = 4, - SPELLED_MONTH = 8, - SPELLED_DAY = 16, - - // Without this validation-flag set, unconfident time-zone expression - // are discarded in the output-callback, e.g. "-08:00, +8". - ALLOW_UNCONFIDENT_TIME_ZONE = 32, -} - -// Parameter info for extraction rule, help rule explanation. -namespace libtextclassifier3.dates; -table ExtractionRuleParameter { - // Bit-wise OR Validation. - validation:uint = 0; - - priority_delta:int; - id:string; - - // The score reflects the confidence score of the date/time match, which is - // set while creating grammar rules. - // e.g. given we have the rule which detect "22.33" as a HH.MM then because - // of ambiguity the confidence of this match maybe relatively less. - annotator_priority_score:float; -} - -// Internal structure used to describe an hour-mapping segment. 
-namespace libtextclassifier3.dates.TimeSpanSpec_; -table Segment { - // From 0 to 24, the beginning hour of the segment, always included. - begin:int; - - // From 0 to 24, the ending hour of the segment, not included if the - // segment is not closed. The value 0 means the beginning of the next - // day, the same value as "begin" means a time-point. - end:int; - - // From -24 to 24, the mapping offset in hours from spanned expressions - // to 24-hour expressions. The value 0 means identical mapping. - offset:int; - - // True if the segment is a closed one instead of a half-open one. - // Always set it to true when describing time-points. - is_closed:bool = false; - - // True if a strict check should be performed onto the segment which - // disallows already-offset hours to be used in spanned expressions, - // e.g. 15:30PM. - is_strict:bool = false; - - // True if the time-span can be used without an explicitly specified - // hour value, then it can generate an exact time point (the "begin" - // o'clock sharp, like "noon") or a time range, like "Tonight". - is_stand_alone:bool = false; -} - -namespace libtextclassifier3.dates; -table TimeSpanSpec { - code:TimespanCode; - segment:[TimeSpanSpec_.Segment]; -} - -namespace libtextclassifier3.dates.TimeZoneNameSpec_; -enum TimeZoneType : int { - // The corresponding name might represent a standard or daylight-saving - // time-zone, depending on some external information, e.g. the date. - AMBIGUOUS = 0, - - // The corresponding name represents a standard time-zone. - STANDARD = 1, - - // The corresponding name represents a daylight-saving time-zone. - DAYLIGHT = 2, -} - -namespace libtextclassifier3.dates; -table TimeZoneNameSpec { - code:TimezoneCode; - type:TimeZoneNameSpec_.TimeZoneType = AMBIGUOUS; - - // Set to true if the corresponding name is internationally used as an - // abbreviation (or expression) of UTC. For example, "GMT" and "Z". 
- is_utc:bool = false; - - // Set to false if the corresponding name is not an abbreviation. For example, - // "Pacific Time" and "China Standard Time". - is_abbreviation:bool = true; -} - diff --git a/annotator/grammar/dates/extractor.cc b/annotator/grammar/dates/extractor.cc deleted file mode 100644 index 8f11937..0000000 --- a/annotator/grammar/dates/extractor.cc +++ /dev/null
@@ -1,912 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -#include "annotator/grammar/dates/extractor.h" - -#include <initializer_list> -#include <map> - -#include "annotator/grammar/dates/utils/date-match.h" -#include "annotator/grammar/dates/utils/date-utils.h" -#include "utils/base/casts.h" -#include "utils/base/logging.h" -#include "utils/strings/numbers.h" - -namespace libtextclassifier3::dates { -namespace { - -// Helper struct for time-related components. -// Extracts all subnodes of a specified type. -struct MatchComponents { - MatchComponents(const grammar::Match* root, - std::initializer_list<int16> types) - : root(root), - components(grammar::SelectAll( - root, [root, &types](const grammar::Match* node) { - if (node == root || node->type == grammar::Match::kUnknownType) { - return false; - } - for (const int64 type : types) { - if (node->type == type) { - return true; - } - } - return false; - })) {} - - // Returns the index of the first submatch of the specified type or -1 if not - // found. - int IndexOf(const int16 type, const int start_index = 0) const { - for (int i = start_index; i < components.size(); i++) { - if (components[i]->type == type) { - return i; - } - } - return -1; - } - - // Returns the first submatch of the specified type, or nullptr if not found. 
- template <typename T> - const T* SubmatchOf(const int16 type, const int start_index = 0) const { - return SubmatchAt<T>(IndexOf(type, start_index)); - } - - template <typename T> - const T* SubmatchAt(const int index) const { - if (index < 0) { - return nullptr; - } - return static_cast<const T*>(components[index]); - } - - const grammar::Match* root; - std::vector<const grammar::Match*> components; -}; - -// Helper method to check whether a time value has valid components. -bool IsValidTimeValue(const TimeValueMatch& time_value) { - // Can only specify seconds if minutes are present. - if (time_value.minute == NO_VAL && time_value.second != NO_VAL) { - return false; - } - // Can only specify fraction of seconds if seconds are present. - if (time_value.second == NO_VAL && time_value.fraction_second >= 0.0) { - return false; - } - - const int8 h = time_value.hour; - const int8 m = (time_value.minute < 0 ? 0 : time_value.minute); - const int8 s = (time_value.second < 0 ? 0 : time_value.second); - const double f = - (time_value.fraction_second < 0.0 ? 0.0 : time_value.fraction_second); - - // Check value bounds. - if (h == NO_VAL || h > 24 || m > 59 || s > 60) { - return false; - } - if (h == 24 && (m != 0 || s != 0 || f > 0.0)) { - return false; - } - if (s == 60 && m != 59) { - return false; - } - return true; -} - -int ParseLeadingDec32Value(const char* c_str) { - int value; - if (ParseInt32(c_str, &value)) { - return value; - } - return NO_VAL; -} - -double ParseLeadingDoubleValue(const char* c_str) { - double value; - if (ParseDouble(c_str, &value)) { - return value; - } - return NO_VAL; -} - -// Extracts digits as an integer and adds a typed match accordingly. 
-template <typename T> -void CheckDigits(const grammar::Match* match, - const NonterminalValue* nonterminal, StringPiece match_text, - grammar::Matcher* matcher) { - TC3_CHECK(match->IsUnaryRule()); - const int value = ParseLeadingDec32Value(match_text.ToString().c_str()); - if (!T::IsValid(value)) { - return; - } - const int num_digits = match_text.size(); - T* result = matcher->AllocateAndInitMatch<T>( - match->lhs, match->codepoint_span, match->match_offset); - result->Reset(); - result->nonterminal = nonterminal; - result->value = value; - result->count_of_digits = num_digits; - result->is_zero_prefixed = (num_digits >= 2 && match_text[0] == '0'); - matcher->AddMatch(result); -} - -// Extracts digits as a decimal (as fraction, as if a "0." is prefixed) and -// adds a typed match to the `er accordingly. -template <typename T> -void CheckDigitsAsFraction(const grammar::Match* match, - const NonterminalValue* nonterminal, - StringPiece match_text, grammar::Matcher* matcher) { - TC3_CHECK(match->IsUnaryRule()); - // TODO(smillius): Should should be achievable in a more straight-forward way. - const double value = - ParseLeadingDoubleValue(("0." + match_text.ToString()).data()); - if (!T::IsValid(value)) { - return; - } - T* result = matcher->AllocateAndInitMatch<T>( - match->lhs, match->codepoint_span, match->match_offset); - result->Reset(); - result->nonterminal = nonterminal; - result->value = value; - result->count_of_digits = match_text.size(); - matcher->AddMatch(result); -} - -// Extracts consecutive digits as multiple integers according to a format and -// adds a type match to the matcher accordingly. 
-template <typename T> -void CheckCombinedDigits(const grammar::Match* match, - const NonterminalValue* nonterminal, - StringPiece match_text, grammar::Matcher* matcher) { - TC3_CHECK(match->IsUnaryRule()); - const std::string& format = - nonterminal->nonterminal_parameter()->combined_digits_format()->str(); - if (match_text.size() != format.size()) { - return; - } - - static std::map<char, CombinedDigitsMatch::Index>& kCombinedDigitsMatchIndex = - *[]() { - return new std::map<char, CombinedDigitsMatch::Index>{ - {'Y', CombinedDigitsMatch::INDEX_YEAR}, - {'M', CombinedDigitsMatch::INDEX_MONTH}, - {'D', CombinedDigitsMatch::INDEX_DAY}, - {'h', CombinedDigitsMatch::INDEX_HOUR}, - {'m', CombinedDigitsMatch::INDEX_MINUTE}, - {'s', CombinedDigitsMatch::INDEX_SECOND}}; - }(); - - struct Segment { - const int index; - const int length; - const int value; - }; - std::vector<Segment> segments; - int slice_start = 0; - while (slice_start < format.size()) { - int slice_end = slice_start + 1; - // Advace right as long as we have the same format character. 
- while (slice_end < format.size() && - format[slice_start] == format[slice_end]) { - slice_end++; - } - - const int slice_length = slice_end - slice_start; - const int value = ParseLeadingDec32Value( - std::string(match_text.data() + slice_start, slice_length).c_str()); - - auto index = kCombinedDigitsMatchIndex.find(format[slice_start]); - if (index == kCombinedDigitsMatchIndex.end()) { - return; - } - if (!T::IsValid(index->second, value)) { - return; - } - segments.push_back(Segment{index->second, slice_length, value}); - slice_start = slice_end; - } - T* result = matcher->AllocateAndInitMatch<T>( - match->lhs, match->codepoint_span, match->match_offset); - result->Reset(); - result->nonterminal = nonterminal; - for (const Segment& segment : segments) { - result->values[segment.index] = segment.value; - } - result->count_of_digits = match_text.size(); - result->is_zero_prefixed = - (match_text[0] == '0' && segments.front().length >= 2); - matcher->AddMatch(result); -} - -// Retrieves the corresponding value from an associated term-value mapping for -// the nonterminal and adds a typed match to the matcher accordingly. -template <typename T> -void CheckMappedValue(const grammar::Match* match, - const NonterminalValue* nonterminal, - grammar::Matcher* matcher) { - const TermValueMatch* term = - grammar::SelectFirstOfType<TermValueMatch>(match, MatchType_TERM_VALUE); - if (term == nullptr) { - return; - } - const int value = term->term_value->value(); - if (!T::IsValid(value)) { - return; - } - T* result = matcher->AllocateAndInitMatch<T>( - match->lhs, match->codepoint_span, match->match_offset); - result->Reset(); - result->nonterminal = nonterminal; - result->value = value; - matcher->AddMatch(result); -} - -// Checks if there is an associated value in the corresponding nonterminal and -// adds a typed match to the matcher accordingly. 
-template <typename T> -void CheckDirectValue(const grammar::Match* match, - const NonterminalValue* nonterminal, - grammar::Matcher* matcher) { - const int value = nonterminal->value()->value(); - if (!T::IsValid(value)) { - return; - } - T* result = matcher->AllocateAndInitMatch<T>( - match->lhs, match->codepoint_span, match->match_offset); - result->Reset(); - result->nonterminal = nonterminal; - result->value = value; - matcher->AddMatch(result); -} - -template <typename T> -void CheckAndAddDirectOrMappedValue(const grammar::Match* match, - const NonterminalValue* nonterminal, - grammar::Matcher* matcher) { - if (nonterminal->value() != nullptr) { - CheckDirectValue<T>(match, nonterminal, matcher); - } else { - CheckMappedValue<T>(match, nonterminal, matcher); - } -} - -template <typename T> -void CheckAndAddNumericValue(const grammar::Match* match, - const NonterminalValue* nonterminal, - StringPiece match_text, - grammar::Matcher* matcher) { - if (nonterminal->nonterminal_parameter() != nullptr && - nonterminal->nonterminal_parameter()->flag() & - NonterminalParameter_::Flag_IS_SPELLED) { - CheckMappedValue<T>(match, nonterminal, matcher); - } else { - CheckDigits<T>(match, nonterminal, match_text, matcher); - } -} - -// Tries to parse as digital time value. -bool ParseDigitalTimeValue(const std::vector<UnicodeText::const_iterator>& text, - const MatchComponents& components, - const NonterminalValue* nonterminal, - grammar::Matcher* matcher) { - // Required fields. - const HourMatch* hour = components.SubmatchOf<HourMatch>(MatchType_HOUR); - if (hour == nullptr || hour->count_of_digits == 0) { - return false; - } - - // Optional fields. 
- const MinuteMatch* minute = - components.SubmatchOf<MinuteMatch>(MatchType_MINUTE); - if (minute != nullptr && minute->count_of_digits == 0) { - return false; - } - const SecondMatch* second = - components.SubmatchOf<SecondMatch>(MatchType_SECOND); - if (second != nullptr && second->count_of_digits == 0) { - return false; - } - const FractionSecondMatch* fraction_second = - components.SubmatchOf<FractionSecondMatch>(MatchType_FRACTION_SECOND); - if (fraction_second != nullptr && fraction_second->count_of_digits == 0) { - return false; - } - - // Validation. - uint32 validation = nonterminal->time_value_parameter()->validation(); - const grammar::Match* end = hour; - if (minute != nullptr) { - if (second != nullptr) { - if (fraction_second != nullptr) { - end = fraction_second; - } else { - end = second; - } - } else { - end = minute; - } - } - - // Check if there is any extra space between h m s f. - if ((validation & - TimeValueParameter_::TimeValueValidation_ALLOW_EXTRA_SPACE) == 0) { - // Check whether there is whitespace between token. - if (minute != nullptr && minute->HasLeadingWhitespace()) { - return false; - } - if (second != nullptr && second->HasLeadingWhitespace()) { - return false; - } - if (fraction_second != nullptr && fraction_second->HasLeadingWhitespace()) { - return false; - } - } - - // Check if there is any ':' or '.' as a prefix or suffix. - if (validation & - TimeValueParameter_::TimeValueValidation_DISALLOW_COLON_DOT_CONTEXT) { - const int begin_pos = hour->codepoint_span.first; - const int end_pos = end->codepoint_span.second; - if (begin_pos > 1 && - (*text[begin_pos - 1] == ':' || *text[begin_pos - 1] == '.') && - isdigit(*text[begin_pos - 2])) { - return false; - } - // Last valid codepoint is at text.size() - 2 as we added the end position - // of text for easier span extraction. 
- if (end_pos < text.size() - 2 && - (*text[end_pos] == ':' || *text[end_pos] == '.') && - isdigit(*text[end_pos + 1])) { - return false; - } - } - - TimeValueMatch time_value; - time_value.Init(components.root->lhs, components.root->codepoint_span, - components.root->match_offset); - time_value.Reset(); - time_value.hour_match = hour; - time_value.minute_match = minute; - time_value.second_match = second; - time_value.fraction_second_match = fraction_second; - time_value.is_hour_zero_prefixed = hour->is_zero_prefixed; - time_value.is_minute_one_digit = - (minute != nullptr && minute->count_of_digits == 1); - time_value.is_second_one_digit = - (second != nullptr && second->count_of_digits == 1); - time_value.hour = hour->value; - time_value.minute = (minute != nullptr ? minute->value : NO_VAL); - time_value.second = (second != nullptr ? second->value : NO_VAL); - time_value.fraction_second = - (fraction_second != nullptr ? fraction_second->value : NO_VAL); - - if (!IsValidTimeValue(time_value)) { - return false; - } - - TimeValueMatch* result = matcher->AllocateMatch<TimeValueMatch>(); - *result = time_value; - matcher->AddMatch(result); - return true; -} - -// Tries to parsing a time from spelled out time components. -bool ParseSpelledTimeValue(const MatchComponents& components, - const NonterminalValue* nonterminal, - grammar::Matcher* matcher) { - // Required fields. - const HourMatch* hour = components.SubmatchOf<HourMatch>(MatchType_HOUR); - if (hour == nullptr || hour->count_of_digits != 0) { - return false; - } - // Optional fields. 
- const MinuteMatch* minute = - components.SubmatchOf<MinuteMatch>(MatchType_MINUTE); - if (minute != nullptr && minute->count_of_digits != 0) { - return false; - } - const SecondMatch* second = - components.SubmatchOf<SecondMatch>(MatchType_SECOND); - if (second != nullptr && second->count_of_digits != 0) { - return false; - } - - uint32 validation = nonterminal->time_value_parameter()->validation(); - // Check if there is any extra space between h m s. - if ((validation & - TimeValueParameter_::TimeValueValidation_ALLOW_EXTRA_SPACE) == 0) { - // Check whether there is whitespace between token. - if (minute != nullptr && minute->HasLeadingWhitespace()) { - return false; - } - if (second != nullptr && second->HasLeadingWhitespace()) { - return false; - } - } - - TimeValueMatch time_value; - time_value.Init(components.root->lhs, components.root->codepoint_span, - components.root->match_offset); - time_value.Reset(); - time_value.hour_match = hour; - time_value.minute_match = minute; - time_value.second_match = second; - time_value.is_hour_zero_prefixed = hour->is_zero_prefixed; - time_value.is_minute_one_digit = - (minute != nullptr && minute->count_of_digits == 1); - time_value.is_second_one_digit = - (second != nullptr && second->count_of_digits == 1); - time_value.hour = hour->value; - time_value.minute = (minute != nullptr ? minute->value : NO_VAL); - time_value.second = (second != nullptr ? second->value : NO_VAL); - - if (!IsValidTimeValue(time_value)) { - return false; - } - - TimeValueMatch* result = matcher->AllocateMatch<TimeValueMatch>(); - *result = time_value; - matcher->AddMatch(result); - return true; -} - -// Reconstructs and validates a time value from a match. 
-void CheckTimeValue(const std::vector<UnicodeText::const_iterator>& text, - const grammar::Match* match, - const NonterminalValue* nonterminal, - grammar::Matcher* matcher) { - MatchComponents components( - match, {MatchType_HOUR, MatchType_MINUTE, MatchType_SECOND, - MatchType_FRACTION_SECOND}); - if (ParseDigitalTimeValue(text, components, nonterminal, matcher)) { - return; - } - if (ParseSpelledTimeValue(components, nonterminal, matcher)) { - return; - } -} - -// Validates a time span match. -void CheckTimeSpan(const grammar::Match* match, - const NonterminalValue* nonterminal, - grammar::Matcher* matcher) { - const TermValueMatch* ts_name = - grammar::SelectFirstOfType<TermValueMatch>(match, MatchType_TERM_VALUE); - const TermValue* term_value = ts_name->term_value; - TC3_CHECK(term_value != nullptr); - TC3_CHECK(term_value->time_span_spec() != nullptr); - const TimeSpanSpec* ts_spec = term_value->time_span_spec(); - TimeSpanMatch* time_span = matcher->AllocateAndInitMatch<TimeSpanMatch>( - match->lhs, match->codepoint_span, match->match_offset); - time_span->Reset(); - time_span->nonterminal = nonterminal; - time_span->time_span_spec = ts_spec; - time_span->time_span_code = ts_spec->code(); - matcher->AddMatch(time_span); -} - -// Validates a time period match. -void CheckTimePeriod(const std::vector<UnicodeText::const_iterator>& text, - const grammar::Match* match, - const NonterminalValue* nonterminal, - grammar::Matcher* matcher) { - int period_value = NO_VAL; - - // If a value mapping exists, use it. 
- if (nonterminal->value() != nullptr) { - period_value = nonterminal->value()->value(); - } else if (const TermValueMatch* term = - grammar::SelectFirstOfType<TermValueMatch>( - match, MatchType_TERM_VALUE)) { - period_value = term->term_value->value(); - } else if (const grammar::Match* digits = - grammar::SelectFirstOfType<grammar::Match>( - match, grammar::Match::kDigitsType)) { - period_value = ParseLeadingDec32Value( - std::string(text[digits->codepoint_span.first].utf8_data(), - text[digits->codepoint_span.second].utf8_data() - - text[digits->codepoint_span.first].utf8_data()) - .c_str()); - } - - if (period_value <= NO_VAL) { - return; - } - - TimePeriodMatch* result = matcher->AllocateAndInitMatch<TimePeriodMatch>( - match->lhs, match->codepoint_span, match->match_offset); - result->Reset(); - result->nonterminal = nonterminal; - result->value = period_value; - matcher->AddMatch(result); -} - -// Reconstructs a date from a relative date rule match. -void CheckRelativeDate(const DateAnnotationOptions& options, - const grammar::Match* match, - const NonterminalValue* nonterminal, - grammar::Matcher* matcher) { - if (!options.enable_special_day_offset && - grammar::SelectFirstOfType<TermValueMatch>(match, MatchType_TERM_VALUE) != - nullptr) { - // Special day offsets, like "Today", "Tomorrow" etc. are not enabled. - return; - } - - RelativeMatch* relative_match = matcher->AllocateAndInitMatch<RelativeMatch>( - match->lhs, match->codepoint_span, match->match_offset); - relative_match->Reset(); - relative_match->nonterminal = nonterminal; - - // Fill relative date information from individual components. - grammar::Traverse(match, [match, relative_match](const grammar::Match* node) { - // Ignore the current match. 
- if (node == match || node->type == grammar::Match::kUnknownType) { - return true; - } - - if (node->type == MatchType_TERM_VALUE) { - const int value = - static_cast<const TermValueMatch*>(node)->term_value->value(); - relative_match->day = abs(value); - if (value >= 0) { - // Marks "today" as in the future. - relative_match->is_future_date = true; - } - relative_match->existing |= - (RelativeMatch::HAS_DAY | RelativeMatch::HAS_IS_FUTURE); - return false; - } - - // Parse info from nonterminal. - const NonterminalValue* nonterminal = - static_cast<const NonterminalMatch*>(node)->nonterminal; - if (nonterminal != nullptr && - nonterminal->relative_parameter() != nullptr) { - const RelativeParameter* relative_parameter = - nonterminal->relative_parameter(); - if (relative_parameter->period() != - RelativeParameter_::Period_PERIOD_UNKNOWN) { - relative_match->is_future_date = - (relative_parameter->period() == - RelativeParameter_::Period_PERIOD_FUTURE); - relative_match->existing |= RelativeMatch::HAS_IS_FUTURE; - } - if (relative_parameter->day_of_week_interpretation() != nullptr) { - relative_match->day_of_week_nonterminal = nonterminal; - relative_match->existing |= RelativeMatch::HAS_DAY_OF_WEEK; - } - } - - // Relative day of week. 
- if (node->type == MatchType_DAY_OF_WEEK) { - relative_match->day_of_week = - static_cast<const DayOfWeekMatch*>(node)->value; - return false; - } - - if (node->type != MatchType_TIME_PERIOD) { - return true; - } - - const TimePeriodMatch* period = static_cast<const TimePeriodMatch*>(node); - switch (nonterminal->relative_parameter()->type()) { - case RelativeParameter_::RelativeType_YEAR: { - relative_match->year = period->value; - relative_match->existing |= RelativeMatch::HAS_YEAR; - break; - } - case RelativeParameter_::RelativeType_MONTH: { - relative_match->month = period->value; - relative_match->existing |= RelativeMatch::HAS_MONTH; - break; - } - case RelativeParameter_::RelativeType_WEEK: { - relative_match->week = period->value; - relative_match->existing |= RelativeMatch::HAS_WEEK; - break; - } - case RelativeParameter_::RelativeType_DAY: { - relative_match->day = period->value; - relative_match->existing |= RelativeMatch::HAS_DAY; - break; - } - case RelativeParameter_::RelativeType_HOUR: { - relative_match->hour = period->value; - relative_match->existing |= RelativeMatch::HAS_HOUR; - break; - } - case RelativeParameter_::RelativeType_MINUTE: { - relative_match->minute = period->value; - relative_match->existing |= RelativeMatch::HAS_MINUTE; - break; - } - case RelativeParameter_::RelativeType_SECOND: { - relative_match->second = period->value; - relative_match->existing |= RelativeMatch::HAS_SECOND; - break; - } - default: - break; - } - - return true; - }); - matcher->AddMatch(relative_match); -} - -bool IsValidTimeZoneOffset(const int time_zone_offset) { - return (time_zone_offset >= -720 && time_zone_offset <= 840 && - time_zone_offset % 15 == 0); -} - -// Parses, validates and adds a time zone offset match. 
-void CheckTimeZoneOffset(const grammar::Match* match, - const NonterminalValue* nonterminal, - grammar::Matcher* matcher) { - MatchComponents components( - match, {MatchType_DIGITS, MatchType_TERM_VALUE, MatchType_NONTERMINAL}); - const TermValueMatch* tz_sign = - components.SubmatchOf<TermValueMatch>(MatchType_TERM_VALUE); - if (tz_sign == nullptr) { - return; - } - const int sign = tz_sign->term_value->value(); - TC3_CHECK(sign == -1 || sign == 1); - - const int tz_digits_index = components.IndexOf(MatchType_DIGITS); - if (tz_digits_index < 0) { - return; - } - const DigitsMatch* tz_digits = - components.SubmatchAt<DigitsMatch>(tz_digits_index); - if (tz_digits == nullptr) { - return; - } - - int offset; - if (tz_digits->count_of_digits >= 3) { - offset = (tz_digits->value / 100) * 60 + (tz_digits->value % 100); - } else { - offset = tz_digits->value * 60; - if (const DigitsMatch* tz_digits_extra = components.SubmatchOf<DigitsMatch>( - MatchType_DIGITS, /*start_index=*/tz_digits_index + 1)) { - offset += tz_digits_extra->value; - } - } - - const NonterminalMatch* tz_offset = - components.SubmatchOf<NonterminalMatch>(MatchType_NONTERMINAL); - if (tz_offset == nullptr) { - return; - } - - const int time_zone_offset = sign * offset; - if (!IsValidTimeZoneOffset(time_zone_offset)) { - return; - } - - TimeZoneOffsetMatch* result = - matcher->AllocateAndInitMatch<TimeZoneOffsetMatch>( - match->lhs, match->codepoint_span, match->match_offset); - result->Reset(); - result->nonterminal = nonterminal; - result->time_zone_offset_param = - tz_offset->nonterminal->time_zone_offset_parameter(); - result->time_zone_offset = time_zone_offset; - matcher->AddMatch(result); -} - -// Validates and adds a time zone name match. 
-void CheckTimeZoneName(const grammar::Match* match, - const NonterminalValue* nonterminal, - grammar::Matcher* matcher) { - TC3_CHECK(match->IsUnaryRule()); - const TermValueMatch* tz_name = - static_cast<const TermValueMatch*>(match->unary_rule_rhs()); - if (tz_name == nullptr) { - return; - } - const TimeZoneNameSpec* tz_name_spec = - tz_name->term_value->time_zone_name_spec(); - TimeZoneNameMatch* result = matcher->AllocateAndInitMatch<TimeZoneNameMatch>( - match->lhs, match->codepoint_span, match->match_offset); - result->Reset(); - result->nonterminal = nonterminal; - result->time_zone_name_spec = tz_name_spec; - result->time_zone_code = tz_name_spec->code(); - matcher->AddMatch(result); -} - -// Adds a mapped term value match containing its value. -void AddTermValue(const grammar::Match* match, const TermValue* term_value, - grammar::Matcher* matcher) { - TermValueMatch* term_match = matcher->AllocateAndInitMatch<TermValueMatch>( - match->lhs, match->codepoint_span, match->match_offset); - term_match->Reset(); - term_match->term_value = term_value; - matcher->AddMatch(term_match); -} - -// Adds a match for a nonterminal. -void AddNonterminal(const grammar::Match* match, - const NonterminalValue* nonterminal, - grammar::Matcher* matcher) { - NonterminalMatch* result = - matcher->AllocateAndInitMatch<NonterminalMatch>(*match); - result->Reset(); - result->nonterminal = nonterminal; - matcher->AddMatch(result); -} - -// Adds a match for an extraction rule that is potentially used in a date range -// rule. 
-void AddExtractionRuleMatch(const grammar::Match* match, - const ExtractionRuleParameter* rule, - grammar::Matcher* matcher) { - ExtractionMatch* result = - matcher->AllocateAndInitMatch<ExtractionMatch>(*match); - result->Reset(); - result->extraction_rule = rule; - matcher->AddMatch(result); -} - -} // namespace - -void DateExtractor::HandleExtractionRuleMatch( - const ExtractionRuleParameter* rule, const grammar::Match* match, - grammar::Matcher* matcher) { - if (rule->id() != nullptr) { - const std::string rule_id = rule->id()->str(); - bool keep = false; - for (const std::string& extra_requested_dates_id : - options_.extra_requested_dates) { - if (extra_requested_dates_id == rule_id) { - keep = true; - break; - } - } - if (!keep) { - return; - } - } - output_.push_back( - Output{rule, matcher->AllocateAndInitMatch<grammar::Match>(*match)}); -} - -void DateExtractor::HandleRangeExtractionRuleMatch(const grammar::Match* match, - grammar::Matcher* matcher) { - // Collect the two datetime roots that make up the range. - std::vector<const grammar::Match*> parts; - grammar::Traverse(match, [match, &parts](const grammar::Match* node) { - if (node == match || node->type == grammar::Match::kUnknownType) { - // Just continue traversing the match. - return true; - } - - // Collect, but don't expand the individual datetime nodes. 
- parts.push_back(node); - return false; - }); - TC3_CHECK_EQ(parts.size(), 2); - range_output_.push_back( - RangeOutput{matcher->AllocateAndInitMatch<grammar::Match>(*match), - /*from=*/parts[0], /*to=*/parts[1]}); -} - -void DateExtractor::MatchFound(const grammar::Match* match, - const grammar::CallbackId type, - const int64 value, grammar::Matcher* matcher) { - switch (type) { - case MatchType_DATETIME_RULE: { - HandleExtractionRuleMatch( - /*rule=*/ - datetime_rules_->extraction_rule()->Get(value), match, matcher); - return; - } - case MatchType_DATETIME_RANGE_RULE: { - HandleRangeExtractionRuleMatch(match, matcher); - return; - } - case MatchType_DATETIME: { - // If an extraction rule is also part of a range extraction rule, then the - // extraction rule is treated as a rule match and nonterminal match. - // This type is used to match the rule as non terminal. - AddExtractionRuleMatch( - match, datetime_rules_->extraction_rule()->Get(value), matcher); - return; - } - case MatchType_TERM_VALUE: { - // Handle mapped terms. - AddTermValue(match, datetime_rules_->term_value()->Get(value), matcher); - return; - } - default: - break; - } - - // Handle non-terminals. 
- const NonterminalValue* nonterminal = - datetime_rules_->nonterminal_value()->Get(value); - StringPiece match_text = - StringPiece(text_[match->codepoint_span.first].utf8_data(), - text_[match->codepoint_span.second].utf8_data() - - text_[match->codepoint_span.first].utf8_data()); - switch (type) { - case MatchType_NONTERMINAL: - AddNonterminal(match, nonterminal, matcher); - break; - case MatchType_DIGITS: - CheckDigits<DigitsMatch>(match, nonterminal, match_text, matcher); - break; - case MatchType_YEAR: - CheckDigits<YearMatch>(match, nonterminal, match_text, matcher); - break; - case MatchType_MONTH: - CheckAndAddNumericValue<MonthMatch>(match, nonterminal, match_text, - matcher); - break; - case MatchType_DAY: - CheckAndAddNumericValue<DayMatch>(match, nonterminal, match_text, - matcher); - break; - case MatchType_DAY_OF_WEEK: - CheckAndAddDirectOrMappedValue<DayOfWeekMatch>(match, nonterminal, - matcher); - break; - case MatchType_HOUR: - CheckAndAddNumericValue<HourMatch>(match, nonterminal, match_text, - matcher); - break; - case MatchType_MINUTE: - CheckAndAddNumericValue<MinuteMatch>(match, nonterminal, match_text, - matcher); - break; - case MatchType_SECOND: - CheckAndAddNumericValue<SecondMatch>(match, nonterminal, match_text, - matcher); - break; - case MatchType_FRACTION_SECOND: - CheckDigitsAsFraction<FractionSecondMatch>(match, nonterminal, match_text, - matcher); - break; - case MatchType_TIME_VALUE: - CheckTimeValue(text_, match, nonterminal, matcher); - break; - case MatchType_TIME_SPAN: - CheckTimeSpan(match, nonterminal, matcher); - break; - case MatchType_TIME_ZONE_NAME: - CheckTimeZoneName(match, nonterminal, matcher); - break; - case MatchType_TIME_ZONE_OFFSET: - CheckTimeZoneOffset(match, nonterminal, matcher); - break; - case MatchType_TIME_PERIOD: - CheckTimePeriod(text_, match, nonterminal, matcher); - break; - case MatchType_RELATIVE_DATE: - CheckRelativeDate(options_, match, nonterminal, matcher); - break; - case 
MatchType_COMBINED_DIGITS: - CheckCombinedDigits<CombinedDigitsMatch>(match, nonterminal, match_text, - matcher); - break; - default: - TC3_VLOG(ERROR) << "Unhandled match type: " << type; - } -} - -} // namespace libtextclassifier3::dates diff --git a/annotator/grammar/dates/extractor.h b/annotator/grammar/dates/extractor.h deleted file mode 100644 index a2658d5..0000000 --- a/annotator/grammar/dates/extractor.h +++ /dev/null
@@ -1,86 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_EXTRACTOR_H_ -#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_EXTRACTOR_H_ - -#include <vector> - -#include "annotator/grammar/dates/annotations/annotation-options.h" -#include "annotator/grammar/dates/dates_generated.h" -#include "utils/base/integral_types.h" -#include "utils/grammar/callback-delegate.h" -#include "utils/grammar/match.h" -#include "utils/grammar/matcher.h" -#include "utils/grammar/types.h" -#include "utils/strings/stringpiece.h" -#include "utils/utf8/unicodetext.h" - -namespace libtextclassifier3::dates { - -// A helper class for the datetime parser that extracts structured data from -// the datetime grammar matches. -// It handles simple sanity checking of the rule matches and interacts with the -// grammar matcher to extract all datetime occurrences in a text. -class DateExtractor : public grammar::CallbackDelegate { - public: - // Represents a date match for an extraction rule. - struct Output { - const ExtractionRuleParameter* rule = nullptr; - const grammar::Match* match = nullptr; - }; - - // Represents a date match from a range extraction rule. 
- struct RangeOutput { - const grammar::Match* match = nullptr; - const grammar::Match* from = nullptr; - const grammar::Match* to = nullptr; - }; - - DateExtractor(const std::vector<UnicodeText::const_iterator>& text, - const DateAnnotationOptions& options, - const DatetimeRules* datetime_rules) - : text_(text), options_(options), datetime_rules_(datetime_rules) {} - - // Handle a rule match in the date time grammar. - // This checks the type of the match and does type dependent checks. - void MatchFound(const grammar::Match* match, grammar::CallbackId type, - int64 value, grammar::Matcher* matcher) override; - - const std::vector<Output>& output() const { return output_; } - const std::vector<RangeOutput>& range_output() const { return range_output_; } - - private: - // Extracts a date from a root rule match. - void HandleExtractionRuleMatch(const ExtractionRuleParameter* rule, - const grammar::Match* match, - grammar::Matcher* matcher); - - // Extracts a date range from a root rule match. - void HandleRangeExtractionRuleMatch(const grammar::Match* match, - grammar::Matcher* matcher); - - const std::vector<UnicodeText::const_iterator>& text_; - const DateAnnotationOptions& options_; - const DatetimeRules* datetime_rules_; - - // Extraction results. - std::vector<Output> output_; - std::vector<RangeOutput> range_output_; -}; - -} // namespace libtextclassifier3::dates - -#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_EXTRACTOR_H_ diff --git a/annotator/grammar/dates/parser.cc b/annotator/grammar/dates/parser.cc deleted file mode 100644 index 8c2527b..0000000 --- a/annotator/grammar/dates/parser.cc +++ /dev/null
@@ -1,793 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -#include "annotator/grammar/dates/parser.h" - -#include "annotator/grammar/dates/extractor.h" -#include "annotator/grammar/dates/utils/date-match.h" -#include "annotator/grammar/dates/utils/date-utils.h" -#include "utils/base/integral_types.h" -#include "utils/base/logging.h" -#include "utils/base/macros.h" -#include "utils/grammar/lexer.h" -#include "utils/grammar/matcher.h" -#include "utils/grammar/rules_generated.h" -#include "utils/grammar/types.h" -#include "utils/strings/split.h" -#include "utils/strings/stringpiece.h" - -namespace libtextclassifier3::dates { -namespace { - -// Helper methods to validate individual components from a date match. - -// Checks the validation requirement of a rule against a match. -// For example if the rule asks for `SPELLED_MONTH`, then we check that the -// match has the right flag. -bool CheckMatchValidationAndFlag( - const grammar::Match* match, const ExtractionRuleParameter* rule, - const ExtractionRuleParameter_::ExtractionValidation validation, - const NonterminalParameter_::Flag flag) { - if (rule == nullptr || (rule->validation() & validation) == 0) { - // No validation requirement. 
- return true; - } - const NonterminalParameter* nonterminal_parameter = - static_cast<const NonterminalMatch*>(match) - ->nonterminal->nonterminal_parameter(); - return (nonterminal_parameter != nullptr && - (nonterminal_parameter->flag() & flag) != 0); -} - -bool GenerateDate(const ExtractionRuleParameter* rule, - const grammar::Match* match, DateMatch* date) { - bool is_valid = true; - - // Post check and assign date components. - grammar::Traverse(match, [rule, date, &is_valid](const grammar::Match* node) { - switch (node->type) { - case MatchType_YEAR: { - if (CheckMatchValidationAndFlag( - node, rule, - ExtractionRuleParameter_::ExtractionValidation_SPELLED_YEAR, - NonterminalParameter_::Flag_IS_SPELLED)) { - date->year_match = static_cast<const YearMatch*>(node); - date->year = date->year_match->value; - } else { - is_valid = false; - } - break; - } - case MatchType_MONTH: { - if (CheckMatchValidationAndFlag( - node, rule, - ExtractionRuleParameter_::ExtractionValidation_SPELLED_MONTH, - NonterminalParameter_::Flag_IS_SPELLED)) { - date->month_match = static_cast<const MonthMatch*>(node); - date->month = date->month_match->value; - } else { - is_valid = false; - } - break; - } - case MatchType_DAY: { - if (CheckMatchValidationAndFlag( - node, rule, - ExtractionRuleParameter_::ExtractionValidation_SPELLED_DAY, - NonterminalParameter_::Flag_IS_SPELLED)) { - date->day_match = static_cast<const DayMatch*>(node); - date->day = date->day_match->value; - } else { - is_valid = false; - } - break; - } - case MatchType_DAY_OF_WEEK: { - date->day_of_week_match = static_cast<const DayOfWeekMatch*>(node); - date->day_of_week = - static_cast<DayOfWeek>(date->day_of_week_match->value); - break; - } - case MatchType_TIME_VALUE: { - date->time_value_match = static_cast<const TimeValueMatch*>(node); - date->hour = date->time_value_match->hour; - date->minute = date->time_value_match->minute; - date->second = date->time_value_match->second; - date->fraction_second = 
date->time_value_match->fraction_second; - return false; - } - case MatchType_TIME_SPAN: { - date->time_span_match = static_cast<const TimeSpanMatch*>(node); - date->time_span_code = date->time_span_match->time_span_code; - return false; - } - case MatchType_TIME_ZONE_NAME: { - date->time_zone_name_match = - static_cast<const TimeZoneNameMatch*>(node); - date->time_zone_code = date->time_zone_name_match->time_zone_code; - return false; - } - case MatchType_TIME_ZONE_OFFSET: { - date->time_zone_offset_match = - static_cast<const TimeZoneOffsetMatch*>(node); - date->time_zone_offset = date->time_zone_offset_match->time_zone_offset; - return false; - } - case MatchType_RELATIVE_DATE: { - date->relative_match = static_cast<const RelativeMatch*>(node); - return false; - } - case MatchType_COMBINED_DIGITS: { - date->combined_digits_match = - static_cast<const CombinedDigitsMatch*>(node); - if (date->combined_digits_match->HasYear()) { - date->year = date->combined_digits_match->GetYear(); - } - if (date->combined_digits_match->HasMonth()) { - date->month = date->combined_digits_match->GetMonth(); - } - if (date->combined_digits_match->HasDay()) { - date->day = date->combined_digits_match->GetDay(); - } - if (date->combined_digits_match->HasHour()) { - date->hour = date->combined_digits_match->GetHour(); - } - if (date->combined_digits_match->HasMinute()) { - date->minute = date->combined_digits_match->GetMinute(); - } - if (date->combined_digits_match->HasSecond()) { - date->second = date->combined_digits_match->GetSecond(); - } - return false; - } - default: - // Expand node further. - return true; - } - - return false; - }); - - if (is_valid) { - date->begin = match->codepoint_span.first; - date->end = match->codepoint_span.second; - date->priority = rule ? rule->priority_delta() : 0; - date->annotator_priority_score = - rule ? 
rule->annotator_priority_score() : 0.0; - } - return is_valid; -} - -bool GenerateFromOrToDateRange(const grammar::Match* match, DateMatch* date) { - return GenerateDate( - /*rule=*/( - match->type == MatchType_DATETIME - ? static_cast<const ExtractionMatch*>(match)->extraction_rule - : nullptr), - match, date); -} - -bool GenerateDateRange(const grammar::Match* match, const grammar::Match* from, - const grammar::Match* to, DateRangeMatch* date_range) { - if (!GenerateFromOrToDateRange(from, &date_range->from)) { - TC3_LOG(WARNING) << "Failed to generate date for `from`."; - return false; - } - if (!GenerateFromOrToDateRange(to, &date_range->to)) { - TC3_LOG(WARNING) << "Failed to generate date for `to`."; - return false; - } - date_range->begin = match->codepoint_span.first; - date_range->end = match->codepoint_span.second; - return true; -} - -bool NormalizeHour(DateMatch* date) { - if (date->time_span_match == nullptr) { - // Nothing to do. - return true; - } - return NormalizeHourByTimeSpan(date->time_span_match->time_span_spec, date); -} - -void CheckAndSetAmbiguousHour(DateMatch* date) { - if (date->HasHour()) { - // Use am-pm ambiguity as default. - if (!date->HasTimeSpanCode() && date->hour >= 1 && date->hour <= 12 && - !(date->time_value_match != nullptr && - date->time_value_match->hour_match != nullptr && - date->time_value_match->hour_match->is_zero_prefixed)) { - date->SetAmbiguousHourProperties(2, 12); - } - } -} - -// Normalizes a date candidate. -// Returns whether the candidate was successfully normalized. -bool NormalizeDate(DateMatch* date) { - // Normalize hour. - if (!NormalizeHour(date)) { - TC3_VLOG(ERROR) << "Hour normalization (according to time-span) failed." 
- << date->DebugString(); - return false; - } - CheckAndSetAmbiguousHour(date); - if (!date->IsValid()) { - TC3_VLOG(ERROR) << "Fields inside date instance are ill-formed " - << date->DebugString(); - } - return true; -} - -// Copies the field from one DateMatch to another whose field is null. for -// example: if the from is "May 1, 8pm", and the to is "9pm", "May 1" will be -// copied to "to". Now we only copy fields for date range requirement.fv -void CopyFieldsForDateMatch(const DateMatch& from, DateMatch* to) { - if (from.time_span_match != nullptr && to->time_span_match == nullptr) { - to->time_span_match = from.time_span_match; - to->time_span_code = from.time_span_code; - } - if (from.month_match != nullptr && to->month_match == nullptr) { - to->month_match = from.month_match; - to->month = from.month; - } -} - -// Normalizes a date range candidate. -// Returns whether the date range was successfully normalized. -bool NormalizeDateRange(DateRangeMatch* date_range) { - CopyFieldsForDateMatch(date_range->from, &date_range->to); - CopyFieldsForDateMatch(date_range->to, &date_range->from); - return (NormalizeDate(&date_range->from) && NormalizeDate(&date_range->to)); -} - -bool CheckDate(const DateMatch& date, const ExtractionRuleParameter* rule) { - // It's possible that "time_zone_name_match == NULL" when - // "HasTimeZoneCode() == true", or "time_zone_offset_match == NULL" when - // "HasTimeZoneOffset() == true" due to inference between endpoints, so we - // must check if they really exist before using them. 
- if (date.HasTimeZoneOffset()) { - if (date.HasTimeZoneCode()) { - if (date.time_zone_name_match != nullptr) { - TC3_CHECK(date.time_zone_name_match->time_zone_name_spec != nullptr); - const TimeZoneNameSpec* spec = - date.time_zone_name_match->time_zone_name_spec; - if (!spec->is_utc()) { - return false; - } - if (!spec->is_abbreviation()) { - return false; - } - } - } else if (date.time_zone_offset_match != nullptr) { - TC3_CHECK(date.time_zone_offset_match->time_zone_offset_param != nullptr); - const TimeZoneOffsetParameter* param = - date.time_zone_offset_match->time_zone_offset_param; - if (param->format() == TimeZoneOffsetParameter_::Format_FORMAT_H || - param->format() == TimeZoneOffsetParameter_::Format_FORMAT_HH) { - return false; - } - if (!(rule->validation() & - ExtractionRuleParameter_:: - ExtractionValidation_ALLOW_UNCONFIDENT_TIME_ZONE)) { - if (param->format() == TimeZoneOffsetParameter_::Format_FORMAT_H_MM || - param->format() == TimeZoneOffsetParameter_::Format_FORMAT_HH_MM || - param->format() == TimeZoneOffsetParameter_::Format_FORMAT_HMM) { - return false; - } - } - } - } - - // Case: 1 April could be extracted as year 1, month april. - // We simply remove this case. - if (!date.HasBcAd() && date.year_match != nullptr && date.year < 1000) { - // We allow case like 11/5/01 - if (date.HasMonth() && date.HasDay() && - date.year_match->count_of_digits == 2) { - } else { - return false; - } - } - - // Ignore the date if the year is larger than 9999 (The maximum number of 4 - // digits). - if (date.year_match != nullptr && date.year > 9999) { - TC3_VLOG(ERROR) << "Year is greater than 9999."; - return false; - } - - // Case: spelled may could be month 5, it also used very common as modal - // verbs. We ignore spelled may as month. 
- if ((rule->validation() & - ExtractionRuleParameter_::ExtractionValidation_SPELLED_MONTH) && - date.month == 5 && !date.HasYear() && !date.HasDay()) { - return false; - } - - return true; -} - -bool CheckContext(const std::vector<UnicodeText::const_iterator>& text, - const DateExtractor::Output& output) { - const uint32 validation = output.rule->validation(); - - // Nothing to check if we don't have any validation requirements for the - // span boundaries. - if ((validation & - (ExtractionRuleParameter_::ExtractionValidation_LEFT_BOUND | - ExtractionRuleParameter_::ExtractionValidation_RIGHT_BOUND)) == 0) { - return true; - } - - const int begin = output.match->codepoint_span.first; - const int end = output.match->codepoint_span.second; - - // So far, we only check that the adjacent character cannot be a separator, - // like /, - or . - if ((validation & - ExtractionRuleParameter_::ExtractionValidation_LEFT_BOUND) != 0) { - if (begin > 0 && (*text[begin - 1] == '/' || *text[begin - 1] == '-' || - *text[begin - 1] == ':')) { - return false; - } - } - if ((validation & - ExtractionRuleParameter_::ExtractionValidation_RIGHT_BOUND) != 0) { - // Last valid codepoint is at text.size() - 2 as we added the end position - // of text for easier span extraction. - if (end < text.size() - 1 && - (*text[end] == '/' || *text[end] == '-' || *text[end] == ':')) { - return false; - } - } - - return true; -} - -// Validates a date match. Returns true if the candidate is valid. -bool ValidateDate(const std::vector<UnicodeText::const_iterator>& text, - const DateExtractor::Output& output, const DateMatch& date) { - if (!CheckDate(date, output.rule)) { - return false; - } - if (!CheckContext(text, output)) { - return false; - } - return true; -} - -// Builds matched date instances from the grammar output. 
-std::vector<DateMatch> BuildDateMatches( - const std::vector<UnicodeText::const_iterator>& text, - const std::vector<DateExtractor::Output>& outputs) { - std::vector<DateMatch> result; - for (const DateExtractor::Output& output : outputs) { - DateMatch date; - if (GenerateDate(output.rule, output.match, &date)) { - if (!NormalizeDate(&date)) { - continue; - } - if (!ValidateDate(text, output, date)) { - continue; - } - result.push_back(date); - } - } - return result; -} - -// Builds matched date range instances from the grammar output. -std::vector<DateRangeMatch> BuildDateRangeMatches( - const std::vector<UnicodeText::const_iterator>& text, - const std::vector<DateExtractor::RangeOutput>& range_outputs) { - std::vector<DateRangeMatch> result; - for (const DateExtractor::RangeOutput& range_output : range_outputs) { - DateRangeMatch date_range; - if (GenerateDateRange(range_output.match, range_output.from, - range_output.to, &date_range)) { - if (!NormalizeDateRange(&date_range)) { - continue; - } - result.push_back(date_range); - } - } - return result; -} - -template <typename T> -void RemoveDeletedMatches(const std::vector<bool>& removed, - std::vector<T>* matches) { - int input = 0; - for (int next = 0; next < matches->size(); ++next) { - if (removed[next]) { - continue; - } - if (input != next) { - (*matches)[input] = (*matches)[next]; - } - input++; - } - matches->resize(input); -} - -// Removes duplicated date or date range instances. -// Overlapping date and date ranges are not considered here. -template <typename T> -void RemoveDuplicatedDates(std::vector<T>* matches) { - // Assumption: matches are sorted ascending by (begin, end). - std::vector<bool> removed(matches->size(), false); - for (int i = 0; i < matches->size(); i++) { - if (removed[i]) { - continue; - } - const T& candidate = matches->at(i); - for (int j = i + 1; j < matches->size(); j++) { - if (removed[j]) { - continue; - } - const T& next = matches->at(j); - - // Not overlapping. 
- if (next.begin >= candidate.end) { - break; - } - - // If matching the same span of text, then check the priority. - if (candidate.begin == next.begin && candidate.end == next.end) { - if (candidate.GetPriority() < next.GetPriority()) { - removed[i] = true; - break; - } else { - removed[j] = true; - continue; - } - } - - // Checks if `next` is fully covered by fields of `candidate`. - if (next.end <= candidate.end) { - removed[j] = true; - continue; - } - - // Checks whether `candidate`/`next` is a refinement. - if (IsRefinement(candidate, next)) { - removed[j] = true; - continue; - } else if (IsRefinement(next, candidate)) { - removed[i] = true; - break; - } - } - } - RemoveDeletedMatches(removed, matches); -} - -// Filters out simple overtriggering simple matches. -bool IsBlacklistedDate(const UniLib& unilib, - const std::vector<UnicodeText::const_iterator>& text, - const DateMatch& match) { - const int begin = match.begin; - const int end = match.end; - if (end - begin != 3) { - return false; - } - - std::string text_lower = - unilib - .ToLowerText( - UTF8ToUnicodeText(text[begin].utf8_data(), - text[end].utf8_data() - text[begin].utf8_data(), - /*do_copy=*/false)) - .ToUTF8String(); - - // "sun" is not a good abbreviation for a standalone day of the week. - if (match.IsStandaloneRelativeDayOfWeek() && - (text_lower == "sun" || text_lower == "mon")) { - return true; - } - - // "mar" is not a good abbreviation for single month. - if (match.HasMonth() && text_lower == "mar") { - return true; - } - - return false; -} - -// Checks if two date matches are adjacent and mergeable. -bool AreDateMatchesAdjacentAndMergeable( - const UniLib& unilib, const std::vector<UnicodeText::const_iterator>& text, - const std::vector<std::string>& ignored_spans, const DateMatch& prev, - const DateMatch& next) { - // Check the context between the two matches. - if (next.begin <= prev.end) { - // The two matches are not adjacent. 
- return false; - } - UnicodeText span; - for (int i = prev.end; i < next.begin; i++) { - const char32 codepoint = *text[i]; - if (unilib.IsWhitespace(codepoint)) { - continue; - } - span.push_back(unilib.ToLower(codepoint)); - } - if (span.empty()) { - return true; - } - const std::string span_text = span.ToUTF8String(); - bool matched = false; - for (const std::string& ignored_span : ignored_spans) { - if (span_text == ignored_span) { - matched = true; - break; - } - } - if (!matched) { - return false; - } - return IsDateMatchMergeable(prev, next); -} - -// Merges adjacent date and date range. -// For e.g. Monday, 5-10pm, the date "Monday" and the time range "5-10pm" will -// be merged -void MergeDateRangeAndDate(const UniLib& unilib, - const std::vector<UnicodeText::const_iterator>& text, - const std::vector<std::string>& ignored_spans, - const std::vector<DateMatch>& dates, - std::vector<DateRangeMatch>* date_ranges) { - // For each range, check the date before or after the it to see if they could - // be merged. Both the range and date array are sorted, so we only need to - // scan the date array once. - int next_date = 0; - for (int i = 0; i < date_ranges->size(); i++) { - DateRangeMatch* date_range = &date_ranges->at(i); - // So far we only merge time range with a date. - if (!date_range->from.HasHour()) { - continue; - } - - for (; next_date < dates.size(); next_date++) { - const DateMatch& date = dates[next_date]; - - // If the range is before the date, we check whether `date_range->to` can - // be merged with the date. 
- if (date_range->end <= date.begin) { - DateMatch merged_date = date; - if (AreDateMatchesAdjacentAndMergeable(unilib, text, ignored_spans, - date_range->to, date)) { - MergeDateMatch(date_range->to, &merged_date, /*update_span=*/true); - date_range->to = merged_date; - date_range->end = date_range->to.end; - MergeDateMatch(date, &date_range->from, /*update_span=*/false); - next_date++; - - // Check the second date after the range to see if it could be merged - // further. For example: 10-11pm, Monday, May 15. 10-11pm is merged - // with Monday and then we check that it could be merged with May 15 - // as well. - if (next_date < dates.size()) { - DateMatch next_match = dates[next_date]; - if (AreDateMatchesAdjacentAndMergeable( - unilib, text, ignored_spans, date_range->to, next_match)) { - MergeDateMatch(date_range->to, &next_match, /*update_span=*/true); - date_range->to = next_match; - date_range->end = date_range->to.end; - MergeDateMatch(dates[next_date], &date_range->from, - /*update_span=*/false); - next_date++; - } - } - } - // Since the range is before the date, we try to check if the next range - // could be merged with the current date. - break; - } else if (date_range->end > date.end && date_range->begin > date.begin) { - // If the range is after the date, we check if `date_range.from` can be - // merged with the date. Here is a special case, the date before range - // could be partially overlapped. This is because the range.from could - // be extracted as year in date. For example: March 3, 10-11pm is - // extracted as date March 3, 2010 and the range 10-11pm. In this - // case, we simply clear the year from date. - DateMatch merged_date = date; - if (date.HasYear() && - date.year_match->codepoint_span.second > date_range->begin) { - merged_date.year_match = nullptr; - merged_date.year = NO_VAL; - merged_date.end = date.year_match->match_offset; - } - // Check and merge the range and the date before the range. 
- if (AreDateMatchesAdjacentAndMergeable(unilib, text, ignored_spans, - merged_date, date_range->from)) { - MergeDateMatch(merged_date, &date_range->from, /*update_span=*/true); - date_range->begin = date_range->from.begin; - MergeDateMatch(merged_date, &date_range->to, /*update_span=*/false); - - // Check if the second date before the range can be merged as well. - if (next_date > 0) { - DateMatch prev_match = dates[next_date - 1]; - if (prev_match.end <= date_range->from.begin) { - if (AreDateMatchesAdjacentAndMergeable(unilib, text, - ignored_spans, prev_match, - date_range->from)) { - MergeDateMatch(prev_match, &date_range->from, - /*update_span=*/true); - date_range->begin = date_range->from.begin; - MergeDateMatch(prev_match, &date_range->to, - /*update_span=*/false); - } - } - } - next_date++; - break; - } else { - // Since the date is before the date range, we move to the next date - // to check if it could be merged with the current range. - continue; - } - } else { - // The date is either fully overlapped by the date range or the date - // span end is after the date range. Move to the next date in both - // cases. - } - } - } -} - -// Removes the dates which are part of a range. e.g. in "May 1 - 3", the date -// "May 1" is fully contained in the range. -void RemoveOverlappedDateByRange(const std::vector<DateRangeMatch>& ranges, - std::vector<DateMatch>* dates) { - int next_date = 0; - std::vector<bool> removed(dates->size(), false); - for (int i = 0; i < ranges.size(); ++i) { - const auto& range = ranges[i]; - for (; next_date < dates->size(); ++next_date) { - const auto& date = dates->at(next_date); - // So far we don't touch the partially overlapped case. - if (date.begin >= range.begin && date.end <= range.end) { - // Fully contained. 
- removed[next_date] = true; - } else if (date.end <= range.begin) { - continue; // date is behind range, go to next date - } else if (date.begin >= range.end) { - break; // range is behind date, go to next range - } - } - } - RemoveDeletedMatches(removed, dates); -} - -// Converts candidate dates and date ranges. -void FillDateInstances( - const UniLib& unilib, const std::vector<UnicodeText::const_iterator>& text, - const DateAnnotationOptions& options, std::vector<DateMatch>* date_matches, - std::vector<DatetimeParseResultSpan>* datetime_parse_result_spans) { - int i = 0; - for (int j = 1; j < date_matches->size(); j++) { - if (options.merge_adjacent_components && - AreDateMatchesAdjacentAndMergeable(unilib, text, options.ignored_spans, - date_matches->at(i), - date_matches->at(j))) { - MergeDateMatch(date_matches->at(i), &date_matches->at(j), true); - } else { - if (!IsBlacklistedDate(unilib, text, date_matches->at(i))) { - DatetimeParseResultSpan datetime_parse_result_span; - FillDateInstance(date_matches->at(i), &datetime_parse_result_span); - datetime_parse_result_spans->push_back(datetime_parse_result_span); - } - } - i = j; - } - if (!IsBlacklistedDate(unilib, text, date_matches->at(i))) { - DatetimeParseResultSpan datetime_parse_result_span; - FillDateInstance(date_matches->at(i), &datetime_parse_result_span); - datetime_parse_result_spans->push_back(datetime_parse_result_span); - } -} - -void FillDateRangeInstances( - const std::vector<DateRangeMatch>& date_range_matches, - std::vector<DatetimeParseResultSpan>* datetime_parse_result_spans) { - for (const DateRangeMatch& date_range_match : date_range_matches) { - DatetimeParseResultSpan datetime_parse_result_span; - FillDateRangeInstance(date_range_match, &datetime_parse_result_span); - datetime_parse_result_spans->push_back(datetime_parse_result_span); - } -} - -// Fills `DatetimeParseResultSpan` from `DateMatch` and `DateRangeMatch` -// instances. 
-std::vector<DatetimeParseResultSpan> GetOutputAsAnnotationList( - const UniLib& unilib, const DateExtractor& extractor, - const std::vector<UnicodeText::const_iterator>& text, - const DateAnnotationOptions& options) { - std::vector<DatetimeParseResultSpan> datetime_parse_result_spans; - std::vector<DateMatch> date_matches = - BuildDateMatches(text, extractor.output()); - - std::sort( - date_matches.begin(), date_matches.end(), - // Order by increasing begin, and decreasing end (decreasing length). - [](const DateMatch& a, const DateMatch& b) { - return (a.begin < b.begin || (a.begin == b.begin && a.end > b.end)); - }); - - if (!date_matches.empty()) { - RemoveDuplicatedDates(&date_matches); - } - - if (options.enable_date_range) { - std::vector<DateRangeMatch> date_range_matches = - BuildDateRangeMatches(text, extractor.range_output()); - - if (!date_range_matches.empty()) { - std::sort( - date_range_matches.begin(), date_range_matches.end(), - // Order by increasing begin, and decreasing end (decreasing length). 
- [](const DateRangeMatch& a, const DateRangeMatch& b) { - return (a.begin < b.begin || (a.begin == b.begin && a.end > b.end)); - }); - RemoveDuplicatedDates(&date_range_matches); - } - - if (!date_matches.empty()) { - MergeDateRangeAndDate(unilib, text, options.ignored_spans, date_matches, - &date_range_matches); - RemoveOverlappedDateByRange(date_range_matches, &date_matches); - } - FillDateRangeInstances(date_range_matches, &datetime_parse_result_spans); - } - - if (!date_matches.empty()) { - FillDateInstances(unilib, text, options, &date_matches, - &datetime_parse_result_spans); - } - return datetime_parse_result_spans; -} - -} // namespace - -std::vector<DatetimeParseResultSpan> DateParser::Parse( - StringPiece text, const std::vector<Token>& tokens, - const std::vector<Locale>& locales, - const DateAnnotationOptions& options) const { - std::vector<UnicodeText::const_iterator> codepoint_offsets; - const UnicodeText text_unicode = UTF8ToUnicodeText(text, - /*do_copy=*/false); - for (auto it = text_unicode.begin(); it != text_unicode.end(); it++) { - codepoint_offsets.push_back(it); - } - codepoint_offsets.push_back(text_unicode.end()); - DateExtractor extractor(codepoint_offsets, options, datetime_rules_); - // Select locale matching rules. - // Only use a shard if locales match or the shard doesn't specify a locale - // restriction. 
- std::vector<const grammar::RulesSet_::Rules*> locale_rules = - SelectLocaleMatchingShards(datetime_rules_->rules(), rules_locales_, - locales); - if (locale_rules.empty()) { - return {}; - } - grammar::Matcher matcher(&unilib_, datetime_rules_->rules(), locale_rules, - &extractor); - lexer_.Process(text_unicode, tokens, /*annotations=*/nullptr, &matcher); - return GetOutputAsAnnotationList(unilib_, extractor, codepoint_offsets, - options); -} - -} // namespace libtextclassifier3::dates diff --git a/annotator/grammar/dates/parser.h b/annotator/grammar/dates/parser.h deleted file mode 100644 index 020c76f..0000000 --- a/annotator/grammar/dates/parser.h +++ /dev/null
@@ -1,65 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -#pragma GCC diagnostic ignored "-Wc++17-extensions" - -#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_PARSER_H_ -#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_PARSER_H_ - -#include <vector> - -#include "annotator/grammar/dates/annotations/annotation-options.h" -#include "annotator/grammar/dates/annotations/annotation.h" -#include "annotator/grammar/dates/dates_generated.h" -#include "annotator/grammar/dates/utils/date-match.h" -#include "utils/grammar/lexer.h" -#include "utils/grammar/rules-utils.h" -#include "utils/i18n/locale.h" -#include "utils/strings/stringpiece.h" -#include "utils/utf8/unilib.h" - -namespace libtextclassifier3::dates { - -// Parses datetime expressions in the input with the datetime grammar and -// constructs, validates, deduplicates and normalizes date time annotations. -class DateParser { - public: - explicit DateParser(const UniLib* unilib, const DatetimeRules* datetime_rules) - : unilib_(*unilib), - lexer_(unilib, datetime_rules->rules()), - datetime_rules_(datetime_rules), - rules_locales_(ParseRulesLocales(datetime_rules->rules())) {} - - // Parses the dates in the input. Makes sure that the results do not - // overlap. 
- std::vector<DatetimeParseResultSpan> Parse( - StringPiece text, const std::vector<Token>& tokens, - const std::vector<Locale>& locales, - const DateAnnotationOptions& options) const; - - private: - const UniLib& unilib_; - const grammar::Lexer lexer_; - - // The datetime grammar. - const DatetimeRules* datetime_rules_; - - // Pre-parsed locales of the rules. - const std::vector<std::vector<Locale>> rules_locales_; -}; - -} // namespace libtextclassifier3::dates - -#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_PARSER_H_ diff --git a/annotator/grammar/dates/timezone-code.fbs b/annotator/grammar/dates/timezone-code.fbs deleted file mode 100755 index ae69885..0000000 --- a/annotator/grammar/dates/timezone-code.fbs +++ /dev/null
@@ -1,592 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -namespace libtextclassifier3.dates; -enum TimezoneCode : int { - TIMEZONE_CODE_NONE = -1, - ETC_UNKNOWN = 0, - PST8PDT = 1, - // Delegate. - - AFRICA_ABIDJAN = 2, - AFRICA_ACCRA = 3, - AFRICA_ADDIS_ABABA = 4, - AFRICA_ALGIERS = 5, - AFRICA_ASMARA = 6, - AFRICA_BAMAKO = 7, - // Delegate. - - AFRICA_BANGUI = 8, - AFRICA_BANJUL = 9, - AFRICA_BISSAU = 10, - AFRICA_BLANTYRE = 11, - AFRICA_BRAZZAVILLE = 12, - AFRICA_BUJUMBURA = 13, - EGYPT = 14, - // Delegate. - - AFRICA_CASABLANCA = 15, - AFRICA_CEUTA = 16, - AFRICA_CONAKRY = 17, - AFRICA_DAKAR = 18, - AFRICA_DAR_ES_SALAAM = 19, - AFRICA_DJIBOUTI = 20, - AFRICA_DOUALA = 21, - AFRICA_EL_AAIUN = 22, - AFRICA_FREETOWN = 23, - AFRICA_GABORONE = 24, - AFRICA_HARARE = 25, - AFRICA_JOHANNESBURG = 26, - AFRICA_KAMPALA = 27, - AFRICA_KHARTOUM = 28, - AFRICA_KIGALI = 29, - AFRICA_KINSHASA = 30, - AFRICA_LAGOS = 31, - AFRICA_LIBREVILLE = 32, - AFRICA_LOME = 33, - AFRICA_LUANDA = 34, - AFRICA_LUBUMBASHI = 35, - AFRICA_LUSAKA = 36, - AFRICA_MALABO = 37, - AFRICA_MAPUTO = 38, - AFRICA_MASERU = 39, - AFRICA_MBABANE = 40, - AFRICA_MOGADISHU = 41, - AFRICA_MONROVIA = 42, - AFRICA_NAIROBI = 43, - AFRICA_NDJAMENA = 44, - AFRICA_NIAMEY = 45, - AFRICA_NOUAKCHOTT = 46, - AFRICA_OUAGADOUGOU = 47, - AFRICA_PORTO_NOVO = 48, - AFRICA_SAO_TOME = 49, - LIBYA = 51, - // Delegate. 
- - AFRICA_TUNIS = 52, - AFRICA_WINDHOEK = 53, - US_ALEUTIAN = 54, - // Delegate. - - US_ALASKA = 55, - // Delegate. - - AMERICA_ANGUILLA = 56, - AMERICA_ANTIGUA = 57, - AMERICA_ARAGUAINA = 58, - AMERICA_BUENOS_AIRES = 59, - AMERICA_CATAMARCA = 60, - AMERICA_CORDOBA = 62, - AMERICA_JUJUY = 63, - AMERICA_ARGENTINA_LA_RIOJA = 64, - AMERICA_MENDOZA = 65, - AMERICA_ARGENTINA_RIO_GALLEGOS = 66, - AMERICA_ARGENTINA_SAN_JUAN = 67, - AMERICA_ARGENTINA_TUCUMAN = 68, - AMERICA_ARGENTINA_USHUAIA = 69, - AMERICA_ARUBA = 70, - AMERICA_ASUNCION = 71, - AMERICA_BAHIA = 72, - AMERICA_BARBADOS = 73, - AMERICA_BELEM = 74, - AMERICA_BELIZE = 75, - AMERICA_BOA_VISTA = 76, - AMERICA_BOGOTA = 77, - AMERICA_BOISE = 78, - AMERICA_CAMBRIDGE_BAY = 79, - AMERICA_CAMPO_GRANDE = 80, - AMERICA_CANCUN = 81, - AMERICA_CARACAS = 82, - AMERICA_CAYENNE = 83, - AMERICA_CAYMAN = 84, - CST6CDT = 85, - // Delegate. - - AMERICA_CHIHUAHUA = 86, - AMERICA_COSTA_RICA = 87, - AMERICA_CUIABA = 88, - AMERICA_CURACAO = 89, - AMERICA_DANMARKSHAVN = 90, - AMERICA_DAWSON = 91, - AMERICA_DAWSON_CREEK = 92, - NAVAJO = 93, - // Delegate. - - US_MICHIGAN = 94, - // Delegate. - - AMERICA_DOMINICA = 95, - CANADA_MOUNTAIN = 96, - // Delegate. - - AMERICA_EIRUNEPE = 97, - AMERICA_EL_SALVADOR = 98, - AMERICA_FORTALEZA = 99, - AMERICA_GLACE_BAY = 100, - AMERICA_GODTHAB = 101, - AMERICA_GOOSE_BAY = 102, - AMERICA_GRAND_TURK = 103, - AMERICA_GRENADA = 104, - AMERICA_GUADELOUPE = 105, - AMERICA_GUATEMALA = 106, - AMERICA_GUAYAQUIL = 107, - AMERICA_GUYANA = 108, - AMERICA_HALIFAX = 109, - // Delegate. - - CUBA = 110, - // Delegate. - - AMERICA_HERMOSILLO = 111, - AMERICA_KNOX_IN = 113, - // Delegate. - - AMERICA_INDIANA_MARENGO = 114, - US_EAST_INDIANA = 115, - AMERICA_INDIANA_VEVAY = 116, - AMERICA_INUVIK = 117, - AMERICA_IQALUIT = 118, - JAMAICA = 119, - // Delegate. 
- - AMERICA_JUNEAU = 120, - AMERICA_KENTUCKY_MONTICELLO = 122, - AMERICA_LA_PAZ = 123, - AMERICA_LIMA = 124, - AMERICA_LOUISVILLE = 125, - AMERICA_MACEIO = 126, - AMERICA_MANAGUA = 127, - BRAZIL_WEST = 128, - // Delegate. - - AMERICA_MARTINIQUE = 129, - MEXICO_BAJASUR = 130, - // Delegate. - - AMERICA_MENOMINEE = 131, - AMERICA_MERIDA = 132, - MEXICO_GENERAL = 133, - // Delegate. - - AMERICA_MIQUELON = 134, - AMERICA_MONTERREY = 135, - AMERICA_MONTEVIDEO = 136, - AMERICA_MONTREAL = 137, - AMERICA_MONTSERRAT = 138, - AMERICA_NASSAU = 139, - EST5EDT = 140, - // Delegate. - - AMERICA_NIPIGON = 141, - AMERICA_NOME = 142, - AMERICA_NORONHA = 143, - // Delegate. - - AMERICA_NORTH_DAKOTA_CENTER = 144, - AMERICA_PANAMA = 145, - AMERICA_PANGNIRTUNG = 146, - AMERICA_PARAMARIBO = 147, - US_ARIZONA = 148, - // Delegate. - - AMERICA_PORT_AU_PRINCE = 149, - AMERICA_PORT_OF_SPAIN = 150, - AMERICA_PORTO_VELHO = 151, - AMERICA_PUERTO_RICO = 152, - AMERICA_RAINY_RIVER = 153, - AMERICA_RANKIN_INLET = 154, - AMERICA_RECIFE = 155, - AMERICA_REGINA = 156, - // Delegate. - - BRAZIL_ACRE = 157, - AMERICA_SANTIAGO = 158, - // Delegate. - - AMERICA_SANTO_DOMINGO = 159, - BRAZIL_EAST = 160, - // Delegate. - - AMERICA_SCORESBYSUND = 161, - AMERICA_ST_JOHNS = 163, - // Delegate. - - AMERICA_ST_KITTS = 164, - AMERICA_ST_LUCIA = 165, - AMERICA_VIRGIN = 166, - // Delegate. - - AMERICA_ST_VINCENT = 167, - AMERICA_SWIFT_CURRENT = 168, - AMERICA_TEGUCIGALPA = 169, - AMERICA_THULE = 170, - AMERICA_THUNDER_BAY = 171, - AMERICA_TIJUANA = 172, - CANADA_EASTERN = 173, - // Delegate. - - AMERICA_TORTOLA = 174, - CANADA_PACIFIC = 175, - // Delegate. - - CANADA_YUKON = 176, - // Delegate. - - CANADA_CENTRAL = 177, - // Delegate. 
- - AMERICA_YAKUTAT = 178, - AMERICA_YELLOWKNIFE = 179, - ANTARCTICA_CASEY = 180, - ANTARCTICA_DAVIS = 181, - ANTARCTICA_DUMONTDURVILLE = 182, - ANTARCTICA_MAWSON = 183, - ANTARCTICA_MCMURDO = 184, - ANTARCTICA_PALMER = 185, - ANTARCTICA_ROTHERA = 186, - ANTARCTICA_SYOWA = 188, - ANTARCTICA_VOSTOK = 189, - ATLANTIC_JAN_MAYEN = 190, - // Delegate. - - ASIA_ADEN = 191, - ASIA_ALMATY = 192, - ASIA_AMMAN = 193, - ASIA_ANADYR = 194, - ASIA_AQTAU = 195, - ASIA_AQTOBE = 196, - ASIA_ASHGABAT = 197, - // Delegate. - - ASIA_BAGHDAD = 198, - ASIA_BAHRAIN = 199, - ASIA_BAKU = 200, - ASIA_BANGKOK = 201, - ASIA_BEIRUT = 202, - ASIA_BISHKEK = 203, - ASIA_BRUNEI = 204, - ASIA_KOLKATA = 205, - // Delegate. - - ASIA_CHOIBALSAN = 206, - ASIA_COLOMBO = 208, - ASIA_DAMASCUS = 209, - ASIA_DACCA = 210, - ASIA_DILI = 211, - ASIA_DUBAI = 212, - ASIA_DUSHANBE = 213, - ASIA_GAZA = 214, - HONGKONG = 216, - // Delegate. - - ASIA_HOVD = 217, - ASIA_IRKUTSK = 218, - ASIA_JAKARTA = 220, - ASIA_JAYAPURA = 221, - ISRAEL = 222, - // Delegate. - - ASIA_KABUL = 223, - ASIA_KAMCHATKA = 224, - ASIA_KARACHI = 225, - ASIA_KATMANDU = 227, - ASIA_KRASNOYARSK = 228, - ASIA_KUALA_LUMPUR = 229, - ASIA_KUCHING = 230, - ASIA_KUWAIT = 231, - ASIA_MACAO = 232, - ASIA_MAGADAN = 233, - ASIA_MAKASSAR = 234, - // Delegate. - - ASIA_MANILA = 235, - ASIA_MUSCAT = 236, - ASIA_NICOSIA = 237, - // Delegate. - - ASIA_NOVOSIBIRSK = 238, - ASIA_OMSK = 239, - ASIA_ORAL = 240, - ASIA_PHNOM_PENH = 241, - ASIA_PONTIANAK = 242, - ASIA_PYONGYANG = 243, - ASIA_QATAR = 244, - ASIA_QYZYLORDA = 245, - ASIA_RANGOON = 246, - ASIA_RIYADH = 247, - ASIA_SAIGON = 248, - ASIA_SAKHALIN = 249, - ASIA_SAMARKAND = 250, - ROK = 251, - // Delegate. - - PRC = 252, - SINGAPORE = 253, - // Delegate. - - ROC = 254, - // Delegate. - - ASIA_TASHKENT = 255, - ASIA_TBILISI = 256, - IRAN = 257, - // Delegate. - - ASIA_THIMBU = 258, - JAPAN = 259, - // Delegate. - - ASIA_ULAN_BATOR = 260, - // Delegate. 
- - ASIA_URUMQI = 261, - ASIA_VIENTIANE = 262, - ASIA_VLADIVOSTOK = 263, - ASIA_YAKUTSK = 264, - ASIA_YEKATERINBURG = 265, - ASIA_YEREVAN = 266, - ATLANTIC_AZORES = 267, - ATLANTIC_BERMUDA = 268, - ATLANTIC_CANARY = 269, - ATLANTIC_CAPE_VERDE = 270, - ATLANTIC_FAROE = 271, - // Delegate. - - ATLANTIC_MADEIRA = 273, - ICELAND = 274, - // Delegate. - - ATLANTIC_SOUTH_GEORGIA = 275, - ATLANTIC_STANLEY = 276, - ATLANTIC_ST_HELENA = 277, - AUSTRALIA_SOUTH = 278, - // Delegate. - - AUSTRALIA_BRISBANE = 279, - // Delegate. - - AUSTRALIA_YANCOWINNA = 280, - // Delegate. - - AUSTRALIA_NORTH = 281, - // Delegate. - - AUSTRALIA_HOBART = 282, - // Delegate. - - AUSTRALIA_LINDEMAN = 283, - AUSTRALIA_LHI = 284, - AUSTRALIA_VICTORIA = 285, - // Delegate. - - AUSTRALIA_WEST = 286, - // Delegate. - - AUSTRALIA_ACT = 287, - EUROPE_AMSTERDAM = 288, - EUROPE_ANDORRA = 289, - EUROPE_ATHENS = 290, - EUROPE_BELGRADE = 292, - EUROPE_BERLIN = 293, - EUROPE_BRATISLAVA = 294, - EUROPE_BRUSSELS = 295, - EUROPE_BUCHAREST = 296, - EUROPE_BUDAPEST = 297, - EUROPE_CHISINAU = 298, - // Delegate. - - EUROPE_COPENHAGEN = 299, - EIRE = 300, - EUROPE_GIBRALTAR = 301, - EUROPE_HELSINKI = 302, - TURKEY = 303, - EUROPE_KALININGRAD = 304, - EUROPE_KIEV = 305, - PORTUGAL = 306, - // Delegate. - - EUROPE_LJUBLJANA = 307, - GB = 308, - EUROPE_LUXEMBOURG = 309, - EUROPE_MADRID = 310, - EUROPE_MALTA = 311, - EUROPE_MARIEHAMN = 312, - EUROPE_MINSK = 313, - EUROPE_MONACO = 314, - W_SU = 315, - // Delegate. - - EUROPE_OSLO = 317, - EUROPE_PARIS = 318, - EUROPE_PRAGUE = 319, - EUROPE_RIGA = 320, - EUROPE_ROME = 321, - EUROPE_SAMARA = 322, - EUROPE_SAN_MARINO = 323, - EUROPE_SARAJEVO = 324, - EUROPE_SIMFEROPOL = 325, - EUROPE_SKOPJE = 326, - EUROPE_SOFIA = 327, - EUROPE_STOCKHOLM = 328, - EUROPE_TALLINN = 329, - EUROPE_TIRANE = 330, - EUROPE_UZHGOROD = 331, - EUROPE_VADUZ = 332, - EUROPE_VATICAN = 333, - EUROPE_VIENNA = 334, - EUROPE_VILNIUS = 335, - POLAND = 336, - // Delegate. 
- - EUROPE_ZAGREB = 337, - EUROPE_ZAPOROZHYE = 338, - EUROPE_ZURICH = 339, - INDIAN_ANTANANARIVO = 340, - INDIAN_CHAGOS = 341, - INDIAN_CHRISTMAS = 342, - INDIAN_COCOS = 343, - INDIAN_COMORO = 344, - INDIAN_KERGUELEN = 345, - INDIAN_MAHE = 346, - INDIAN_MALDIVES = 347, - INDIAN_MAURITIUS = 348, - INDIAN_MAYOTTE = 349, - INDIAN_REUNION = 350, - PACIFIC_APIA = 351, - NZ = 352, - NZ_CHAT = 353, - PACIFIC_EASTER = 354, - PACIFIC_EFATE = 355, - PACIFIC_ENDERBURY = 356, - PACIFIC_FAKAOFO = 357, - PACIFIC_FIJI = 358, - PACIFIC_FUNAFUTI = 359, - PACIFIC_GALAPAGOS = 360, - PACIFIC_GAMBIER = 361, - PACIFIC_GUADALCANAL = 362, - PACIFIC_GUAM = 363, - US_HAWAII = 364, - // Delegate. - - PACIFIC_JOHNSTON = 365, - PACIFIC_KIRITIMATI = 366, - PACIFIC_KOSRAE = 367, - KWAJALEIN = 368, - PACIFIC_MAJURO = 369, - PACIFIC_MARQUESAS = 370, - PACIFIC_MIDWAY = 371, - PACIFIC_NAURU = 372, - PACIFIC_NIUE = 373, - PACIFIC_NORFOLK = 374, - PACIFIC_NOUMEA = 375, - US_SAMOA = 376, - // Delegate. - - PACIFIC_PALAU = 377, - PACIFIC_PITCAIRN = 378, - PACIFIC_PONAPE = 379, - PACIFIC_PORT_MORESBY = 380, - PACIFIC_RAROTONGA = 381, - PACIFIC_SAIPAN = 382, - PACIFIC_TAHITI = 383, - PACIFIC_TARAWA = 384, - PACIFIC_TONGATAPU = 385, - PACIFIC_YAP = 386, - PACIFIC_WAKE = 387, - PACIFIC_WALLIS = 388, - AMERICA_ATIKOKAN = 390, - AUSTRALIA_CURRIE = 391, - ETC_GMT_EAST_14 = 392, - ETC_GMT_EAST_13 = 393, - ETC_GMT_EAST_12 = 394, - ETC_GMT_EAST_11 = 395, - ETC_GMT_EAST_10 = 396, - ETC_GMT_EAST_9 = 397, - ETC_GMT_EAST_8 = 398, - ETC_GMT_EAST_7 = 399, - ETC_GMT_EAST_6 = 400, - ETC_GMT_EAST_5 = 401, - ETC_GMT_EAST_4 = 402, - ETC_GMT_EAST_3 = 403, - ETC_GMT_EAST_2 = 404, - ETC_GMT_EAST_1 = 405, - GMT = 406, - // Delegate. - - ETC_GMT_WEST_1 = 407, - ETC_GMT_WEST_2 = 408, - ETC_GMT_WEST_3 = 409, - SYSTEMV_AST4 = 410, - // Delegate. - - EST = 411, - SYSTEMV_CST6 = 412, - // Delegate. - - MST = 413, - // Delegate. - - SYSTEMV_PST8 = 414, - // Delegate. - - SYSTEMV_YST9 = 415, - // Delegate. - - HST = 416, - // Delegate. 
- - ETC_GMT_WEST_11 = 417, - ETC_GMT_WEST_12 = 418, - AMERICA_NORTH_DAKOTA_NEW_SALEM = 419, - AMERICA_INDIANA_PETERSBURG = 420, - AMERICA_INDIANA_VINCENNES = 421, - AMERICA_MONCTON = 422, - AMERICA_BLANC_SABLON = 423, - EUROPE_GUERNSEY = 424, - EUROPE_ISLE_OF_MAN = 425, - EUROPE_JERSEY = 426, - EUROPE_PODGORICA = 427, - EUROPE_VOLGOGRAD = 428, - AMERICA_INDIANA_WINAMAC = 429, - AUSTRALIA_EUCLA = 430, - AMERICA_INDIANA_TELL_CITY = 431, - AMERICA_RESOLUTE = 432, - AMERICA_ARGENTINA_SAN_LUIS = 433, - AMERICA_SANTAREM = 434, - AMERICA_ARGENTINA_SALTA = 435, - AMERICA_BAHIA_BANDERAS = 436, - AMERICA_MARIGOT = 437, - AMERICA_MATAMOROS = 438, - AMERICA_OJINAGA = 439, - AMERICA_SANTA_ISABEL = 440, - AMERICA_ST_BARTHELEMY = 441, - ANTARCTICA_MACQUARIE = 442, - ASIA_NOVOKUZNETSK = 443, - AFRICA_JUBA = 444, - AMERICA_METLAKATLA = 445, - AMERICA_NORTH_DAKOTA_BEULAH = 446, - AMERICA_SITKA = 447, - ASIA_HEBRON = 448, - AMERICA_CRESTON = 449, - AMERICA_KRALENDIJK = 450, - AMERICA_LOWER_PRINCES = 451, - ANTARCTICA_TROLL = 452, - ASIA_KHANDYGA = 453, - ASIA_UST_NERA = 454, - EUROPE_BUSINGEN = 455, - ASIA_CHITA = 456, - ASIA_SREDNEKOLYMSK = 457, -} - diff --git a/annotator/grammar/dates/utils/annotation-keys.cc b/annotator/grammar/dates/utils/annotation-keys.cc deleted file mode 100644 index 659268f..0000000 --- a/annotator/grammar/dates/utils/annotation-keys.cc +++ /dev/null
@@ -1,28 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -#include "annotator/grammar/dates/utils/annotation-keys.h" - -namespace libtextclassifier3 { -namespace dates { -const char* const kDateTimeType = "dateTime"; -const char* const kDateTimeRangeType = "dateTimeRange"; -const char* const kDateTime = "dateTime"; -const char* const kDateTimeSupplementary = "dateTimeSupplementary"; -const char* const kDateTimeRelative = "dateTimeRelative"; -const char* const kDateTimeRangeFrom = "dateTimeRangeFrom"; -const char* const kDateTimeRangeTo = "dateTimeRangeTo"; -} // namespace dates -} // namespace libtextclassifier3 diff --git a/annotator/grammar/dates/utils/annotation-keys.h b/annotator/grammar/dates/utils/annotation-keys.h deleted file mode 100644 index 5cddaec..0000000 --- a/annotator/grammar/dates/utils/annotation-keys.h +++ /dev/null
@@ -1,58 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_UTILS_ANNOTATION_KEYS_H_ -#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_UTILS_ANNOTATION_KEYS_H_ - -namespace libtextclassifier3 { -namespace dates { - -// Date time specific constants not defined in standard schemas. -// -// Date annotator output two type of annotation. One is date&time like "May 1", -// "12:20pm", etc. Another is range like "2pm - 3pm". The two string identify -// the type of annotation and are used as type in Thing proto. -extern const char* const kDateTimeType; -extern const char* const kDateTimeRangeType; - -// kDateTime contains most common field for date time. It's integer array and -// the format is (year, month, day, hour, minute, second, fraction_sec, -// day_of_week). All eight fields must be provided. If the field is not -// extracted, the value is -1 in the array. -extern const char* const kDateTime; - -// kDateTimeSupplementary contains uncommon field like timespan, timezone. It's -// integer array and the format is (bc_ad, timespan_code, timezone_code, -// timezone_offset). Al four fields must be provided. If the field is not -// extracted, the value is -1 in the array. -extern const char* const kDateTimeSupplementary; - -// kDateTimeRelative contains fields for relative date time. 
It's integer -// array and the format is (is_future, year, month, day, week, hour, minute, -// second, day_of_week, dow_interpretation*). The first nine fields must be -// provided and dow_interpretation could have zero or multiple values. -// If the field is not extracted, the value is -1 in the array. -extern const char* const kDateTimeRelative; - -// Date time range specific constants not defined in standard schemas. -// kDateTimeRangeFrom and kDateTimeRangeTo define the from/to of a date/time -// range. The value is thing object which contains a date time. -extern const char* const kDateTimeRangeFrom; -extern const char* const kDateTimeRangeTo; - -} // namespace dates -} // namespace libtextclassifier3 - -#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_UTILS_ANNOTATION_KEYS_H_ diff --git a/annotator/grammar/dates/utils/date-match.cc b/annotator/grammar/dates/utils/date-match.cc deleted file mode 100644 index 5ece2b4..0000000 --- a/annotator/grammar/dates/utils/date-match.cc +++ /dev/null
@@ -1,439 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -#include "annotator/grammar/dates/utils/date-match.h" - -#include <algorithm> - -#include "annotator/grammar/dates/utils/date-utils.h" -#include "annotator/types.h" -#include "utils/strings/append.h" - -static const int kAM = 0; -static const int kPM = 1; - -namespace libtextclassifier3 { -namespace dates { - -namespace { -static int GetMeridiemValue(const TimespanCode& timespan_code) { - switch (timespan_code) { - case TimespanCode_AM: - case TimespanCode_MIDNIGHT: - // MIDNIGHT [3] -> AM - return kAM; - case TimespanCode_TONIGHT: - // TONIGHT [11] -> PM - case TimespanCode_NOON: - // NOON [2] -> PM - case TimespanCode_PM: - return kPM; - case TimespanCode_TIMESPAN_CODE_NONE: - default: - TC3_LOG(WARNING) << "Failed to extract time span code."; - } - return NO_VAL; -} - -static int GetRelativeCount(const RelativeParameter* relative_parameter) { - for (const int interpretation : - *relative_parameter->day_of_week_interpretation()) { - switch (interpretation) { - case RelativeParameter_::Interpretation_NEAREST_LAST: - case RelativeParameter_::Interpretation_PREVIOUS: - return -1; - case RelativeParameter_::Interpretation_SECOND_LAST: - return -2; - case RelativeParameter_::Interpretation_SECOND_NEXT: - return 2; - case RelativeParameter_::Interpretation_COMING: - case RelativeParameter_::Interpretation_SOME: - case 
RelativeParameter_::Interpretation_NEAREST: - case RelativeParameter_::Interpretation_NEAREST_NEXT: - return 1; - case RelativeParameter_::Interpretation_CURRENT: - return 0; - } - } - return 0; -} -} // namespace - -using strings::JoinStrings; -using strings::SStringAppendF; - -std::string DateMatch::DebugString() const { - std::string res; -#if !defined(NDEBUG) - if (begin >= 0 && end >= 0) { - SStringAppendF(&res, 0, "[%u,%u)", begin, end); - } - - if (HasDayOfWeek()) { - SStringAppendF(&res, 0, "%u", day_of_week); - } - - if (HasYear()) { - int year_output = year; - if (HasBcAd() && bc_ad == BCAD_BC) { - year_output = -year; - } - SStringAppendF(&res, 0, "%u/", year_output); - } else { - SStringAppendF(&res, 0, "____/"); - } - - if (HasMonth()) { - SStringAppendF(&res, 0, "%u/", month); - } else { - SStringAppendF(&res, 0, "__/"); - } - - if (HasDay()) { - SStringAppendF(&res, 0, "%u ", day); - } else { - SStringAppendF(&res, 0, "__ "); - } - - if (HasHour()) { - SStringAppendF(&res, 0, "%u:", hour); - } else { - SStringAppendF(&res, 0, "__:"); - } - - if (HasMinute()) { - SStringAppendF(&res, 0, "%u:", minute); - } else { - SStringAppendF(&res, 0, "__:"); - } - - if (HasSecond()) { - if (HasFractionSecond()) { - SStringAppendF(&res, 0, "%u.%lf ", second, fraction_second); - } else { - SStringAppendF(&res, 0, "%u ", second); - } - } else { - SStringAppendF(&res, 0, "__ "); - } - - if (HasTimeSpanCode() && TimespanCode_TIMESPAN_CODE_NONE < time_span_code && - time_span_code <= TimespanCode_MAX) { - SStringAppendF(&res, 0, "TS=%u ", time_span_code); - } - - if (HasTimeZoneCode() && time_zone_code != -1) { - SStringAppendF(&res, 0, "TZ= %u ", time_zone_code); - } - - if (HasTimeZoneOffset()) { - SStringAppendF(&res, 0, "TZO=%u ", time_zone_offset); - } - - if (HasRelativeDate()) { - const RelativeMatch* rm = relative_match; - SStringAppendF(&res, 0, (rm->is_future_date ? 
"future " : "past ")); - if (rm->day_of_week != NO_VAL) { - SStringAppendF(&res, 0, "DOW:%d ", rm->day_of_week); - } - if (rm->year != NO_VAL) { - SStringAppendF(&res, 0, "Y:%d ", rm->year); - } - if (rm->month != NO_VAL) { - SStringAppendF(&res, 0, "M:%d ", rm->month); - } - if (rm->day != NO_VAL) { - SStringAppendF(&res, 0, "D:%d ", rm->day); - } - if (rm->week != NO_VAL) { - SStringAppendF(&res, 0, "W:%d ", rm->week); - } - if (rm->hour != NO_VAL) { - SStringAppendF(&res, 0, "H:%d ", rm->hour); - } - if (rm->minute != NO_VAL) { - SStringAppendF(&res, 0, "M:%d ", rm->minute); - } - if (rm->second != NO_VAL) { - SStringAppendF(&res, 0, "S:%d ", rm->second); - } - } - - SStringAppendF(&res, 0, "prio=%d ", priority); - SStringAppendF(&res, 0, "conf-score=%lf ", annotator_priority_score); - - if (IsHourAmbiguous()) { - std::vector<int8> values; - GetPossibleHourValues(&values); - std::string str_values; - - for (unsigned int i = 0; i < values.size(); ++i) { - SStringAppendF(&str_values, 0, "%u,", values[i]); - } - SStringAppendF(&res, 0, "amb=%s ", str_values.c_str()); - } - - std::vector<std::string> tags; - if (is_inferred) { - tags.push_back("inferred"); - } - if (!tags.empty()) { - SStringAppendF(&res, 0, "tag=%s ", JoinStrings(",", tags).c_str()); - } -#endif // !defined(NDEBUG) - return res; -} - -void DateMatch::GetPossibleHourValues(std::vector<int8>* values) const { - TC3_CHECK(values != nullptr); - values->clear(); - if (HasHour()) { - int8 possible_hour = hour; - values->push_back(possible_hour); - for (int count = 1; count < ambiguous_hour_count; ++count) { - possible_hour += ambiguous_hour_interval; - if (possible_hour >= 24) { - possible_hour -= 24; - } - values->push_back(possible_hour); - } - } -} - -DatetimeComponent::RelativeQualifier DateMatch::GetRelativeQualifier() const { - if (HasRelativeDate()) { - if (relative_match->existing & RelativeMatch::HAS_IS_FUTURE) { - if (!relative_match->is_future_date) { - return 
DatetimeComponent::RelativeQualifier::PAST; - } - } - return DatetimeComponent::RelativeQualifier::FUTURE; - } - return DatetimeComponent::RelativeQualifier::UNSPECIFIED; -} - -// Embed RelativeQualifier information of DatetimeComponent as a sign of -// relative counter field of datetime component i.e. relative counter is -// negative when relative qualifier RelativeQualifier::PAST. -int GetAdjustedRelativeCounter( - const DatetimeComponent::RelativeQualifier& relative_qualifier, - const int relative_counter) { - if (DatetimeComponent::RelativeQualifier::PAST == relative_qualifier) { - return -relative_counter; - } - return relative_counter; -} - -Optional<DatetimeComponent> CreateDatetimeComponent( - const DatetimeComponent::ComponentType& component_type, - const DatetimeComponent::RelativeQualifier& relative_qualifier, - const int absolute_value, const int relative_value) { - if (absolute_value == NO_VAL && relative_value == NO_VAL) { - return Optional<DatetimeComponent>(); - } - return Optional<DatetimeComponent>(DatetimeComponent( - component_type, - (relative_value != NO_VAL) - ? relative_qualifier - : DatetimeComponent::RelativeQualifier::UNSPECIFIED, - (absolute_value != NO_VAL) ? absolute_value : 0, - (relative_value != NO_VAL) - ? 
GetAdjustedRelativeCounter(relative_qualifier, relative_value) - : 0)); -} - -Optional<DatetimeComponent> CreateDayOfWeekComponent( - const RelativeMatch* relative_match, - const DatetimeComponent::RelativeQualifier& relative_qualifier, - const DayOfWeek& absolute_day_of_week) { - DatetimeComponent::RelativeQualifier updated_relative_qualifier = - relative_qualifier; - int absolute_value = absolute_day_of_week; - int relative_value = NO_VAL; - if (relative_match) { - relative_value = relative_match->day_of_week; - if (relative_match->existing & RelativeMatch::HAS_DAY_OF_WEEK) { - if (relative_match->IsStandaloneRelativeDayOfWeek() && - absolute_day_of_week == DayOfWeek_DOW_NONE) { - absolute_value = relative_match->day_of_week; - } - // Check if the relative date has day of week with week period. - if (relative_match->existing & RelativeMatch::HAS_WEEK) { - relative_value = 1; - } else { - const NonterminalValue* nonterminal = - relative_match->day_of_week_nonterminal; - TC3_CHECK(nonterminal != nullptr); - TC3_CHECK(nonterminal->relative_parameter()); - const RelativeParameter* rp = nonterminal->relative_parameter(); - if (rp->day_of_week_interpretation()) { - relative_value = GetRelativeCount(rp); - if (relative_value < 0) { - relative_value = abs(relative_value); - updated_relative_qualifier = - DatetimeComponent::RelativeQualifier::PAST; - } else if (relative_value > 0) { - updated_relative_qualifier = - DatetimeComponent::RelativeQualifier::FUTURE; - } - } - } - } - } - return CreateDatetimeComponent(DatetimeComponent::ComponentType::DAY_OF_WEEK, - updated_relative_qualifier, absolute_value, - relative_value); -} - -// Resolve the year’s ambiguity. -// If the year in the date has 4 digits i.e. DD/MM/YYYY then there is no -// ambiguity, the year value is YYYY but certain format i.e. MM/DD/YY is -// ambiguous e.g. in {April/23/15} year value can be 15 or 1915 or 2015. -// Following heuristic is used to resolve the ambiguity. 
-// - For YYYY there is nothing to resolve. -// - For all YY years -// - Value less than 50 will be resolved to 20YY -// - Value greater or equal 50 will be resolved to 19YY -static int InterpretYear(int parsed_year) { - if (parsed_year == NO_VAL) { - return parsed_year; - } - if (parsed_year < 100) { - if (parsed_year < 50) { - return parsed_year + 2000; - } - return parsed_year + 1900; - } - return parsed_year; -} - -Optional<DatetimeComponent> DateMatch::GetDatetimeComponent( - const DatetimeComponent::ComponentType& component_type) const { - switch (component_type) { - case DatetimeComponent::ComponentType::YEAR: - return CreateDatetimeComponent( - component_type, GetRelativeQualifier(), InterpretYear(year), - (relative_match != nullptr) ? relative_match->year : NO_VAL); - case DatetimeComponent::ComponentType::MONTH: - return CreateDatetimeComponent( - component_type, GetRelativeQualifier(), month, - (relative_match != nullptr) ? relative_match->month : NO_VAL); - case DatetimeComponent::ComponentType::DAY_OF_MONTH: - return CreateDatetimeComponent( - component_type, GetRelativeQualifier(), day, - (relative_match != nullptr) ? relative_match->day : NO_VAL); - case DatetimeComponent::ComponentType::HOUR: - return CreateDatetimeComponent( - component_type, GetRelativeQualifier(), hour, - (relative_match != nullptr) ? relative_match->hour : NO_VAL); - case DatetimeComponent::ComponentType::MINUTE: - return CreateDatetimeComponent( - component_type, GetRelativeQualifier(), minute, - (relative_match != nullptr) ? relative_match->minute : NO_VAL); - case DatetimeComponent::ComponentType::SECOND: - return CreateDatetimeComponent( - component_type, GetRelativeQualifier(), second, - (relative_match != nullptr) ? 
relative_match->second : NO_VAL); - case DatetimeComponent::ComponentType::DAY_OF_WEEK: - return CreateDayOfWeekComponent(relative_match, GetRelativeQualifier(), - day_of_week); - case DatetimeComponent::ComponentType::MERIDIEM: - return CreateDatetimeComponent(component_type, GetRelativeQualifier(), - GetMeridiemValue(time_span_code), NO_VAL); - case DatetimeComponent::ComponentType::ZONE_OFFSET: - if (HasTimeZoneOffset()) { - return Optional<DatetimeComponent>(DatetimeComponent( - component_type, DatetimeComponent::RelativeQualifier::UNSPECIFIED, - time_zone_offset, /*arg_relative_count=*/0)); - } - return Optional<DatetimeComponent>(); - case DatetimeComponent::ComponentType::WEEK: - return CreateDatetimeComponent( - component_type, GetRelativeQualifier(), NO_VAL, - HasRelativeDate() ? relative_match->week : NO_VAL); - default: - return Optional<DatetimeComponent>(); - } -} - -bool DateMatch::IsValid() const { - if (!HasYear() && HasBcAd()) { - return false; - } - if (!HasMonth() && HasYear() && (HasDay() || HasDayOfWeek())) { - return false; - } - if (!HasDay() && HasDayOfWeek() && (HasYear() || HasMonth())) { - return false; - } - if (!HasDay() && !HasDayOfWeek() && HasHour() && (HasYear() || HasMonth())) { - return false; - } - if (!HasHour() && (HasMinute() || HasSecond() || HasFractionSecond())) { - return false; - } - if (!HasMinute() && (HasSecond() || HasFractionSecond())) { - return false; - } - if (!HasSecond() && HasFractionSecond()) { - return false; - } - // Check whether day exists in a month, to exclude cases like "April 31". 
- if (HasDay() && HasMonth() && day > GetLastDayOfMonth(year, month)) { - return false; - } - return (HasDateFields() || HasTimeFields() || HasRelativeDate()); -} - -void DateMatch::FillDatetimeComponents( - std::vector<DatetimeComponent>* datetime_component) const { - static const std::vector<DatetimeComponent::ComponentType>* - kDatetimeComponents = new std::vector<DatetimeComponent::ComponentType>{ - DatetimeComponent::ComponentType::ZONE_OFFSET, - DatetimeComponent::ComponentType::MERIDIEM, - DatetimeComponent::ComponentType::SECOND, - DatetimeComponent::ComponentType::MINUTE, - DatetimeComponent::ComponentType::HOUR, - DatetimeComponent::ComponentType::DAY_OF_MONTH, - DatetimeComponent::ComponentType::DAY_OF_WEEK, - DatetimeComponent::ComponentType::WEEK, - DatetimeComponent::ComponentType::MONTH, - DatetimeComponent::ComponentType::YEAR}; - - for (const DatetimeComponent::ComponentType& component_type : - *kDatetimeComponents) { - Optional<DatetimeComponent> date_time = - GetDatetimeComponent(component_type); - if (date_time.has_value()) { - datetime_component->emplace_back(date_time.value()); - } - } -} - -std::string DateRangeMatch::DebugString() const { - std::string res; - // The method is only called for debugging purposes. -#if !defined(NDEBUG) - if (begin >= 0 && end >= 0) { - SStringAppendF(&res, 0, "[%u,%u)\n", begin, end); - } - SStringAppendF(&res, 0, "from: %s \n", from.DebugString().c_str()); - SStringAppendF(&res, 0, "to: %s\n", to.DebugString().c_str()); -#endif // !defined(NDEBUG) - return res; -} - -} // namespace dates -} // namespace libtextclassifier3 diff --git a/annotator/grammar/dates/utils/date-match.h b/annotator/grammar/dates/utils/date-match.h deleted file mode 100644 index 5e87cf2..0000000 --- a/annotator/grammar/dates/utils/date-match.h +++ /dev/null
@@ -1,536 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_UTILS_DATE_MATCH_H_ -#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_UTILS_DATE_MATCH_H_ - -#include <stddef.h> -#include <stdint.h> - -#include <algorithm> -#include <vector> - -#include "annotator/grammar/dates/dates_generated.h" -#include "annotator/grammar/dates/timezone-code_generated.h" -#include "utils/grammar/match.h" - -namespace libtextclassifier3 { -namespace dates { - -static constexpr int NO_VAL = -1; - -// POD match data structure. -struct MatchBase : public grammar::Match { - void Reset() { type = MatchType::MatchType_UNKNOWN; } -}; - -struct ExtractionMatch : public MatchBase { - const ExtractionRuleParameter* extraction_rule; - - void Reset() { - MatchBase::Reset(); - type = MatchType::MatchType_DATETIME_RULE; - extraction_rule = nullptr; - } -}; - -struct TermValueMatch : public MatchBase { - const TermValue* term_value; - - void Reset() { - MatchBase::Reset(); - type = MatchType::MatchType_TERM_VALUE; - term_value = nullptr; - } -}; - -struct NonterminalMatch : public MatchBase { - const NonterminalValue* nonterminal; - - void Reset() { - MatchBase::Reset(); - type = MatchType::MatchType_NONTERMINAL; - nonterminal = nullptr; - } -}; - -struct IntegerMatch : public NonterminalMatch { - int value; - int8 count_of_digits; // When expression is in digits format. 
- bool is_zero_prefixed; // When expression is in digits format. - - void Reset() { - NonterminalMatch::Reset(); - value = NO_VAL; - count_of_digits = 0; - is_zero_prefixed = false; - } -}; - -struct DigitsMatch : public IntegerMatch { - void Reset() { - IntegerMatch::Reset(); - type = MatchType::MatchType_DIGITS; - } - - static bool IsValid(int x) { return true; } -}; - -struct YearMatch : public IntegerMatch { - void Reset() { - IntegerMatch::Reset(); - type = MatchType::MatchType_YEAR; - } - - static bool IsValid(int x) { return x >= 1; } -}; - -struct MonthMatch : public IntegerMatch { - void Reset() { - IntegerMatch::Reset(); - type = MatchType::MatchType_MONTH; - } - - static bool IsValid(int x) { return (x >= 1 && x <= 12); } -}; - -struct DayMatch : public IntegerMatch { - void Reset() { - IntegerMatch::Reset(); - type = MatchType::MatchType_DAY; - } - - static bool IsValid(int x) { return (x >= 1 && x <= 31); } -}; - -struct HourMatch : public IntegerMatch { - void Reset() { - IntegerMatch::Reset(); - type = MatchType::MatchType_HOUR; - } - - static bool IsValid(int x) { return (x >= 0 && x <= 24); } -}; - -struct MinuteMatch : public IntegerMatch { - void Reset() { - IntegerMatch::Reset(); - type = MatchType::MatchType_MINUTE; - } - - static bool IsValid(int x) { return (x >= 0 && x <= 59); } -}; - -struct SecondMatch : public IntegerMatch { - void Reset() { - IntegerMatch::Reset(); - type = MatchType::MatchType_SECOND; - } - - static bool IsValid(int x) { return (x >= 0 && x <= 60); } -}; - -struct DecimalMatch : public NonterminalMatch { - double value; - int8 count_of_digits; // When expression is in digits format. 
- - void Reset() { - NonterminalMatch::Reset(); - value = NO_VAL; - count_of_digits = 0; - } -}; - -struct FractionSecondMatch : public DecimalMatch { - void Reset() { - DecimalMatch::Reset(); - type = MatchType::MatchType_FRACTION_SECOND; - } - - static bool IsValid(double x) { return (x >= 0.0 && x < 1.0); } -}; - -// CombinedIntegersMatch<N> is used for expressions containing multiple (up -// to N) matches of integers without delimeters between them (because -// CFG-grammar is based on tokenizer, it could not split a token into several -// pieces like using regular-expression). For example, "1130" contains "11" -// and "30" meaning November 30. -template <int N> -struct CombinedIntegersMatch : public NonterminalMatch { - enum { - SIZE = N, - }; - - int values[SIZE]; - int8 count_of_digits; // When expression is in digits format. - bool is_zero_prefixed; // When expression is in digits format. - - void Reset() { - NonterminalMatch::Reset(); - for (int i = 0; i < SIZE; ++i) { - values[i] = NO_VAL; - } - count_of_digits = 0; - is_zero_prefixed = false; - } -}; - -struct CombinedDigitsMatch : public CombinedIntegersMatch<6> { - enum Index { - INDEX_YEAR = 0, - INDEX_MONTH = 1, - INDEX_DAY = 2, - INDEX_HOUR = 3, - INDEX_MINUTE = 4, - INDEX_SECOND = 5, - }; - - bool HasYear() const { return values[INDEX_YEAR] != NO_VAL; } - bool HasMonth() const { return values[INDEX_MONTH] != NO_VAL; } - bool HasDay() const { return values[INDEX_DAY] != NO_VAL; } - bool HasHour() const { return values[INDEX_HOUR] != NO_VAL; } - bool HasMinute() const { return values[INDEX_MINUTE] != NO_VAL; } - bool HasSecond() const { return values[INDEX_SECOND] != NO_VAL; } - - int GetYear() const { return values[INDEX_YEAR]; } - int GetMonth() const { return values[INDEX_MONTH]; } - int GetDay() const { return values[INDEX_DAY]; } - int GetHour() const { return values[INDEX_HOUR]; } - int GetMinute() const { return values[INDEX_MINUTE]; } - int GetSecond() const { return values[INDEX_SECOND]; } - 
- void Reset() { - CombinedIntegersMatch<SIZE>::Reset(); - type = MatchType::MatchType_COMBINED_DIGITS; - } - - static bool IsValid(int i, int x) { - switch (i) { - case INDEX_YEAR: - return YearMatch::IsValid(x); - case INDEX_MONTH: - return MonthMatch::IsValid(x); - case INDEX_DAY: - return DayMatch::IsValid(x); - case INDEX_HOUR: - return HourMatch::IsValid(x); - case INDEX_MINUTE: - return MinuteMatch::IsValid(x); - case INDEX_SECOND: - return SecondMatch::IsValid(x); - default: - return false; - } - } -}; - -struct TimeValueMatch : public NonterminalMatch { - const HourMatch* hour_match; - const MinuteMatch* minute_match; - const SecondMatch* second_match; - const FractionSecondMatch* fraction_second_match; - - bool is_hour_zero_prefixed : 1; - bool is_minute_one_digit : 1; - bool is_second_one_digit : 1; - - int8 hour; - int8 minute; - int8 second; - double fraction_second; - - void Reset() { - NonterminalMatch::Reset(); - type = MatchType::MatchType_TIME_VALUE; - hour_match = nullptr; - minute_match = nullptr; - second_match = nullptr; - fraction_second_match = nullptr; - is_hour_zero_prefixed = false; - is_minute_one_digit = false; - is_second_one_digit = false; - hour = NO_VAL; - minute = NO_VAL; - second = NO_VAL; - fraction_second = NO_VAL; - } -}; - -struct TimeSpanMatch : public NonterminalMatch { - const TimeSpanSpec* time_span_spec; - TimespanCode time_span_code; - - void Reset() { - NonterminalMatch::Reset(); - type = MatchType::MatchType_TIME_SPAN; - time_span_spec = nullptr; - time_span_code = TimespanCode_TIMESPAN_CODE_NONE; - } -}; - -struct TimeZoneNameMatch : public NonterminalMatch { - const TimeZoneNameSpec* time_zone_name_spec; - TimezoneCode time_zone_code; - - void Reset() { - NonterminalMatch::Reset(); - type = MatchType::MatchType_TIME_ZONE_NAME; - time_zone_name_spec = nullptr; - time_zone_code = TimezoneCode_TIMEZONE_CODE_NONE; - } -}; - -struct TimeZoneOffsetMatch : public NonterminalMatch { - const TimeZoneOffsetParameter* 
time_zone_offset_param; - int16 time_zone_offset; - - void Reset() { - NonterminalMatch::Reset(); - type = MatchType::MatchType_TIME_ZONE_OFFSET; - time_zone_offset_param = nullptr; - time_zone_offset = 0; - } -}; - -struct DayOfWeekMatch : public IntegerMatch { - void Reset() { - IntegerMatch::Reset(); - type = MatchType::MatchType_DAY_OF_WEEK; - } - - static bool IsValid(int x) { - return (x > DayOfWeek_DOW_NONE && x <= DayOfWeek_MAX); - } -}; - -struct TimePeriodMatch : public NonterminalMatch { - int value; - - void Reset() { - NonterminalMatch::Reset(); - type = MatchType::MatchType_TIME_PERIOD; - value = NO_VAL; - } -}; - -struct RelativeMatch : public NonterminalMatch { - enum { - HAS_NONE = 0, - HAS_YEAR = 1 << 0, - HAS_MONTH = 1 << 1, - HAS_DAY = 1 << 2, - HAS_WEEK = 1 << 3, - HAS_HOUR = 1 << 4, - HAS_MINUTE = 1 << 5, - HAS_SECOND = 1 << 6, - HAS_DAY_OF_WEEK = 1 << 7, - HAS_IS_FUTURE = 1 << 31, - }; - uint32 existing; - - int year; - int month; - int day; - int week; - int hour; - int minute; - int second; - const NonterminalValue* day_of_week_nonterminal; - int8 day_of_week; - bool is_future_date; - - bool HasDay() const { return existing & HAS_DAY; } - - bool HasDayFields() const { return existing & (HAS_DAY | HAS_DAY_OF_WEEK); } - - bool HasTimeValueFields() const { - return existing & (HAS_HOUR | HAS_MINUTE | HAS_SECOND); - } - - bool IsStandaloneRelativeDayOfWeek() const { - return (existing & HAS_DAY_OF_WEEK) && (existing & ~HAS_DAY_OF_WEEK) == 0; - } - - void Reset() { - NonterminalMatch::Reset(); - type = MatchType::MatchType_RELATIVE_DATE; - existing = HAS_NONE; - year = NO_VAL; - month = NO_VAL; - day = NO_VAL; - week = NO_VAL; - hour = NO_VAL; - minute = NO_VAL; - second = NO_VAL; - day_of_week = NO_VAL; - is_future_date = false; - } -}; - -// This is not necessarily POD, it is used to keep the final matched result. -struct DateMatch { - // Sub-matches in the date match. 
- const YearMatch* year_match = nullptr; - const MonthMatch* month_match = nullptr; - const DayMatch* day_match = nullptr; - const DayOfWeekMatch* day_of_week_match = nullptr; - const TimeValueMatch* time_value_match = nullptr; - const TimeSpanMatch* time_span_match = nullptr; - const TimeZoneNameMatch* time_zone_name_match = nullptr; - const TimeZoneOffsetMatch* time_zone_offset_match = nullptr; - const RelativeMatch* relative_match = nullptr; - const CombinedDigitsMatch* combined_digits_match = nullptr; - - // [begin, end) indicates the Document position where the date or date range - // was found. - int begin = -1; - int end = -1; - int priority = 0; - float annotator_priority_score = 0.0; - - int year = NO_VAL; - int8 month = NO_VAL; - int8 day = NO_VAL; - DayOfWeek day_of_week = DayOfWeek_DOW_NONE; - BCAD bc_ad = BCAD_BCAD_NONE; - int8 hour = NO_VAL; - int8 minute = NO_VAL; - int8 second = NO_VAL; - double fraction_second = NO_VAL; - TimespanCode time_span_code = TimespanCode_TIMESPAN_CODE_NONE; - int time_zone_code = TimezoneCode_TIMEZONE_CODE_NONE; - int16 time_zone_offset = std::numeric_limits<int16>::min(); - - // Fields about ambiguous hours. These fields are used to interpret the - // possible values of ambiguous hours. Since all kinds of known ambiguities - // are in the form of arithmetic progression (starting from .hour field), - // we can use "ambiguous_hour_count" to denote the count of ambiguous hours, - // and use "ambiguous_hour_interval" to denote the distance between a pair - // of adjacent possible hours. Values in the arithmetic progression are - // shrunk into [0, 23] (MOD 24). One can use the GetPossibleHourValues() - // method for the complete list of possible hours. - uint8 ambiguous_hour_count = 0; - uint8 ambiguous_hour_interval = 0; - - bool is_inferred = false; - - // This field is set in function PerformRefinements to remove some DateMatch - // like overlapped, duplicated, etc. 
- bool is_removed = false; - - std::string DebugString() const; - - bool HasYear() const { return year != NO_VAL; } - bool HasMonth() const { return month != NO_VAL; } - bool HasDay() const { return day != NO_VAL; } - bool HasDayOfWeek() const { return day_of_week != DayOfWeek_DOW_NONE; } - bool HasBcAd() const { return bc_ad != BCAD_BCAD_NONE; } - bool HasHour() const { return hour != NO_VAL; } - bool HasMinute() const { return minute != NO_VAL; } - bool HasSecond() const { return second != NO_VAL; } - bool HasFractionSecond() const { return fraction_second != NO_VAL; } - bool HasTimeSpanCode() const { - return time_span_code != TimespanCode_TIMESPAN_CODE_NONE; - } - bool HasTimeZoneCode() const { - return time_zone_code != TimezoneCode_TIMEZONE_CODE_NONE; - } - bool HasTimeZoneOffset() const { - return time_zone_offset != std::numeric_limits<int16>::min(); - } - - bool HasRelativeDate() const { return relative_match != nullptr; } - - bool IsHourAmbiguous() const { return ambiguous_hour_count >= 2; } - - bool IsStandaloneTime() const { - return (HasHour() || HasMinute()) && !HasDayOfWeek() && !HasDay() && - !HasMonth() && !HasYear(); - } - - void SetAmbiguousHourProperties(uint8 count, uint8 interval) { - ambiguous_hour_count = count; - ambiguous_hour_interval = interval; - } - - // Outputs all the possible hour values. If current DateMatch does not - // contain an hour, nothing will be output. If the hour is not ambiguous, - // only one value (= .hour) will be output. This method clears the vector - // "values" first, and it is not guaranteed that the values in the vector - // are in a sorted order. 
- void GetPossibleHourValues(std::vector<int8>* values) const; - - int GetPriority() const { return priority; } - - float GetAnnotatorPriorityScore() const { return annotator_priority_score; } - - bool IsStandaloneRelativeDayOfWeek() const { - return (HasRelativeDate() && - relative_match->IsStandaloneRelativeDayOfWeek() && - !HasDateFields() && !HasTimeFields() && !HasTimeSpanCode()); - } - - bool HasDateFields() const { - return (HasYear() || HasMonth() || HasDay() || HasDayOfWeek() || HasBcAd()); - } - bool HasTimeValueFields() const { - return (HasHour() || HasMinute() || HasSecond() || HasFractionSecond()); - } - bool HasTimeSpanFields() const { return HasTimeSpanCode(); } - bool HasTimeZoneFields() const { - return (HasTimeZoneCode() || HasTimeZoneOffset()); - } - bool HasTimeFields() const { - return (HasTimeValueFields() || HasTimeSpanFields() || HasTimeZoneFields()); - } - - bool IsValid() const; - - // Overall relative qualifier of the DateMatch e.g. 2 year ago is 'PAST' and - // next week is 'FUTURE'. - DatetimeComponent::RelativeQualifier GetRelativeQualifier() const; - - // Getter method to get the 'DatetimeComponent' of given 'ComponentType'. - Optional<DatetimeComponent> GetDatetimeComponent( - const DatetimeComponent::ComponentType& component_type) const; - - void FillDatetimeComponents( - std::vector<DatetimeComponent>* datetime_component) const; -}; - -// Represent a matched date range which includes the from and to matched date. 
-struct DateRangeMatch { - int begin = -1; - int end = -1; - - DateMatch from; - DateMatch to; - - std::string DebugString() const; - - int GetPriority() const { - return std::max(from.GetPriority(), to.GetPriority()); - } - - float GetAnnotatorPriorityScore() const { - return std::max(from.GetAnnotatorPriorityScore(), - to.GetAnnotatorPriorityScore()); - } -}; - -} // namespace dates -} // namespace libtextclassifier3 - -#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_UTILS_DATE_MATCH_H_ diff --git a/annotator/grammar/dates/utils/date-utils.cc b/annotator/grammar/dates/utils/date-utils.cc deleted file mode 100644 index 94552fd..0000000 --- a/annotator/grammar/dates/utils/date-utils.cc +++ /dev/null
@@ -1,400 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -#pragma GCC diagnostic ignored "-Wunused-function" - -#include "annotator/grammar/dates/utils/date-utils.h" - -#include <algorithm> -#include <ctime> - -#include "annotator/grammar/dates/annotations/annotation-util.h" -#include "annotator/grammar/dates/dates_generated.h" -#include "annotator/grammar/dates/utils/annotation-keys.h" -#include "annotator/grammar/dates/utils/date-match.h" -#include "annotator/types.h" -#include "utils/base/macros.h" - -namespace libtextclassifier3 { -namespace dates { - -bool IsLeapYear(int year) { - // For the sake of completeness, we want to be able to decide - // whether a year is a leap year all the way back to 0 Julian, or - // 4714 BCE. But we don't want to take the modulus of a negative - // number, because this may not be very well-defined or portable. So - // we increment the year by some large multiple of 400, which is the - // periodicity of this leap-year calculation. 
- if (year < 0) { - year += 8000; - } - return ((year) % 4 == 0 && ((year) % 100 != 0 || (year) % 400 == 0)); -} - -namespace { -#define SECSPERMIN (60) -#define MINSPERHOUR (60) -#define HOURSPERDAY (24) -#define DAYSPERWEEK (7) -#define DAYSPERNYEAR (365) -#define DAYSPERLYEAR (366) -#define MONSPERYEAR (12) - -const int8 kDaysPerMonth[2][1 + MONSPERYEAR] = { - {-1, 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31}, - {-1, 31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31}, -}; -} // namespace - -int8 GetLastDayOfMonth(int year, int month) { - if (year == 0) { // No year specified - return kDaysPerMonth[1][month]; - } - return kDaysPerMonth[IsLeapYear(year)][month]; -} - -namespace { -inline bool IsHourInSegment(const TimeSpanSpec_::Segment* segment, int8 hour, - bool is_exact) { - return (hour >= segment->begin() && - (hour < segment->end() || - (hour == segment->end() && is_exact && segment->is_closed()))); -} - -Property* FindOrCreateDefaultDateTime(AnnotationData* inst) { - // Refer comments for kDateTime in annotation-keys.h to see the format. 
- static constexpr int kDefault[] = {-1, -1, -1, -1, -1, -1, -1, -1}; - - int idx = GetPropertyIndex(kDateTime, *inst); - if (idx < 0) { - idx = AddRepeatedIntProperty(kDateTime, kDefault, TC3_ARRAYSIZE(kDefault), - inst); - } - return &inst->properties[idx]; -} - -void IncrementDayOfWeek(DayOfWeek* dow) { - static const DayOfWeek dow_ring[] = {DayOfWeek_MONDAY, DayOfWeek_TUESDAY, - DayOfWeek_WEDNESDAY, DayOfWeek_THURSDAY, - DayOfWeek_FRIDAY, DayOfWeek_SATURDAY, - DayOfWeek_SUNDAY, DayOfWeek_MONDAY}; - const auto& cur_dow = - std::find(std::begin(dow_ring), std::end(dow_ring), *dow); - if (cur_dow != std::end(dow_ring)) { - *dow = *std::next(cur_dow); - } -} -} // namespace - -bool NormalizeHourByTimeSpan(const TimeSpanSpec* ts_spec, DateMatch* date) { - if (ts_spec->segment() == nullptr) { - return false; - } - if (date->HasHour()) { - const bool is_exact = - (!date->HasMinute() || - (date->minute == 0 && - (!date->HasSecond() || - (date->second == 0 && - (!date->HasFractionSecond() || date->fraction_second == 0.0))))); - for (const TimeSpanSpec_::Segment* segment : *ts_spec->segment()) { - if (IsHourInSegment(segment, date->hour + segment->offset(), is_exact)) { - date->hour += segment->offset(); - return true; - } - if (!segment->is_strict() && - IsHourInSegment(segment, date->hour, is_exact)) { - return true; - } - } - } else { - for (const TimeSpanSpec_::Segment* segment : *ts_spec->segment()) { - if (segment->is_stand_alone()) { - if (segment->begin() == segment->end()) { - date->hour = segment->begin(); - } - // Allow stand-alone time-span points and ranges. 
- return true; - } - } - } - return false; -} - -bool IsRefinement(const DateMatch& a, const DateMatch& b) { - int count = 0; - if (b.HasBcAd()) { - if (!a.HasBcAd() || a.bc_ad != b.bc_ad) return false; - } else if (a.HasBcAd()) { - if (a.bc_ad == BCAD_BC) return false; - ++count; - } - if (b.HasYear()) { - if (!a.HasYear() || a.year != b.year) return false; - } else if (a.HasYear()) { - ++count; - } - if (b.HasMonth()) { - if (!a.HasMonth() || a.month != b.month) return false; - } else if (a.HasMonth()) { - ++count; - } - if (b.HasDay()) { - if (!a.HasDay() || a.day != b.day) return false; - } else if (a.HasDay()) { - ++count; - } - if (b.HasDayOfWeek()) { - if (!a.HasDayOfWeek() || a.day_of_week != b.day_of_week) return false; - } else if (a.HasDayOfWeek()) { - ++count; - } - if (b.HasHour()) { - if (!a.HasHour()) return false; - std::vector<int8> possible_hours; - b.GetPossibleHourValues(&possible_hours); - if (std::find(possible_hours.begin(), possible_hours.end(), a.hour) == - possible_hours.end()) { - return false; - } - } else if (a.HasHour()) { - ++count; - } - if (b.HasMinute()) { - if (!a.HasMinute() || a.minute != b.minute) return false; - } else if (a.HasMinute()) { - ++count; - } - if (b.HasSecond()) { - if (!a.HasSecond() || a.second != b.second) return false; - } else if (a.HasSecond()) { - ++count; - } - if (b.HasFractionSecond()) { - if (!a.HasFractionSecond() || a.fraction_second != b.fraction_second) - return false; - } else if (a.HasFractionSecond()) { - ++count; - } - if (b.HasTimeSpanCode()) { - if (!a.HasTimeSpanCode() || a.time_span_code != b.time_span_code) - return false; - } else if (a.HasTimeSpanCode()) { - ++count; - } - if (b.HasTimeZoneCode()) { - if (!a.HasTimeZoneCode() || a.time_zone_code != b.time_zone_code) - return false; - } else if (a.HasTimeZoneCode()) { - ++count; - } - if (b.HasTimeZoneOffset()) { - if (!a.HasTimeZoneOffset() || a.time_zone_offset != b.time_zone_offset) - return false; - } else if (a.HasTimeZoneOffset()) { 
- ++count; - } - return (count > 0 || a.priority >= b.priority); -} - -bool IsRefinement(const DateRangeMatch& a, const DateRangeMatch& b) { - return false; -} - -bool IsPrecedent(const DateMatch& a, const DateMatch& b) { - if (a.HasYear() && b.HasYear()) { - if (a.year < b.year) return true; - if (a.year > b.year) return false; - } - - if (a.HasMonth() && b.HasMonth()) { - if (a.month < b.month) return true; - if (a.month > b.month) return false; - } - - if (a.HasDay() && b.HasDay()) { - if (a.day < b.day) return true; - if (a.day > b.day) return false; - } - - if (a.HasHour() && b.HasHour()) { - if (a.hour < b.hour) return true; - if (a.hour > b.hour) return false; - } - - if (a.HasMinute() && b.HasHour()) { - if (a.minute < b.hour) return true; - if (a.minute > b.hour) return false; - } - - if (a.HasSecond() && b.HasSecond()) { - if (a.second < b.hour) return true; - if (a.second > b.hour) return false; - } - - return false; -} - -void FillDateInstance(const DateMatch& date, - DatetimeParseResultSpan* instance) { - instance->span.first = date.begin; - instance->span.second = date.end; - instance->priority_score = date.GetAnnotatorPriorityScore(); - DatetimeParseResult datetime_parse_result; - date.FillDatetimeComponents(&datetime_parse_result.datetime_components); - instance->data.emplace_back(datetime_parse_result); -} - -void FillDateRangeInstance(const DateRangeMatch& range, - DatetimeParseResultSpan* instance) { - instance->span.first = range.begin; - instance->span.second = range.end; - instance->priority_score = range.GetAnnotatorPriorityScore(); - - // Filling from DatetimeParseResult. - instance->data.emplace_back(); - range.from.FillDatetimeComponents(&instance->data.back().datetime_components); - - // Filling to DatetimeParseResult. 
- instance->data.emplace_back(); - range.to.FillDatetimeComponents(&instance->data.back().datetime_components); -} - -namespace { -bool AnyOverlappedField(const DateMatch& prev, const DateMatch& next) { -#define Field(f) \ - if (prev.f && next.f) return true - Field(year_match); - Field(month_match); - Field(day_match); - Field(day_of_week_match); - Field(time_value_match); - Field(time_span_match); - Field(time_zone_name_match); - Field(time_zone_offset_match); - Field(relative_match); - Field(combined_digits_match); -#undef Field - return false; -} - -void MergeDateMatchImpl(const DateMatch& prev, DateMatch* next, - bool update_span) { -#define RM(f) \ - if (!next->f) next->f = prev.f - RM(year_match); - RM(month_match); - RM(day_match); - RM(day_of_week_match); - RM(time_value_match); - RM(time_span_match); - RM(time_zone_name_match); - RM(time_zone_offset_match); - RM(relative_match); - RM(combined_digits_match); -#undef RM - -#define RV(f) \ - if (next->f == NO_VAL) next->f = prev.f - RV(year); - RV(month); - RV(day); - RV(hour); - RV(minute); - RV(second); - RV(fraction_second); -#undef RV - -#define RE(f, v) \ - if (next->f == v) next->f = prev.f - RE(day_of_week, DayOfWeek_DOW_NONE); - RE(bc_ad, BCAD_BCAD_NONE); - RE(time_span_code, TimespanCode_TIMESPAN_CODE_NONE); - RE(time_zone_code, TimezoneCode_TIMEZONE_CODE_NONE); -#undef RE - - if (next->time_zone_offset == std::numeric_limits<int16>::min()) { - next->time_zone_offset = prev.time_zone_offset; - } - - next->priority = std::max(next->priority, prev.priority); - next->annotator_priority_score = - std::max(next->annotator_priority_score, prev.annotator_priority_score); - if (update_span) { - next->begin = std::min(next->begin, prev.begin); - next->end = std::max(next->end, prev.end); - } -} -} // namespace - -bool IsDateMatchMergeable(const DateMatch& prev, const DateMatch& next) { - // Do not merge if they share the same field. 
- if (AnyOverlappedField(prev, next)) { - return false; - } - - // It's impossible that both prev and next have relative date since it's - // excluded by overlapping check before. - if (prev.HasRelativeDate() || next.HasRelativeDate()) { - // If one of them is relative date, then we merge: - // - if relative match shouldn't have time, and always has DOW or day. - // - if not both relative match and non relative match has day. - // - if non relative match has time or day. - const DateMatch* rm = &prev; - const DateMatch* non_rm = &prev; - if (prev.HasRelativeDate()) { - non_rm = &next; - } else { - rm = &next; - } - - const RelativeMatch* relative_match = rm->relative_match; - // Relative Match should have day or DOW but no time. - if (!relative_match->HasDayFields() || - relative_match->HasTimeValueFields()) { - return false; - } - // Check if both relative match and non relative match has day. - if (non_rm->HasDateFields() && relative_match->HasDay()) { - return false; - } - // Non relative match should have either hour (time) or day (date). - if (!non_rm->HasHour() && !non_rm->HasDay()) { - return false; - } - } else { - // Only one match has date and another has time. - if ((prev.HasDateFields() && next.HasDateFields()) || - (prev.HasTimeFields() && next.HasTimeFields())) { - return false; - } - // DOW never be extracted as a single DateMatch except in RelativeMatch. So - // here, we always merge one with day and another one with hour. 
- if (!(prev.HasDay() || next.HasDay()) || - !(prev.HasHour() || next.HasHour())) { - return false; - } - } - return true; -} - -void MergeDateMatch(const DateMatch& prev, DateMatch* next, bool update_span) { - if (IsDateMatchMergeable(prev, *next)) { - MergeDateMatchImpl(prev, next, update_span); - } -} - -} // namespace dates -} // namespace libtextclassifier3 diff --git a/annotator/grammar/dates/utils/date-utils.h b/annotator/grammar/dates/utils/date-utils.h deleted file mode 100644 index 834e89f..0000000 --- a/annotator/grammar/dates/utils/date-utils.h +++ /dev/null
@@ -1,81 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_UTILS_DATE_UTILS_H_ -#define LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_UTILS_DATE_UTILS_H_ - -#include <stddef.h> -#include <stdint.h> - -#include <ctime> -#include <vector> - -#include "annotator/grammar/dates/annotations/annotation.h" -#include "annotator/grammar/dates/utils/date-match.h" -#include "utils/base/casts.h" - -namespace libtextclassifier3 { -namespace dates { - -bool IsLeapYear(int year); - -int8 GetLastDayOfMonth(int year, int month); - -// Normalizes hour value of the specified date using the specified time-span -// specification. Returns true if the original hour value (can be no-value) -// is compatible with the time-span and gets normalized successfully, or -// false otherwise. -bool NormalizeHourByTimeSpan(const TimeSpanSpec* ts_spec, DateMatch* date); - -// Returns true iff "a" is considered as a refinement of "b". For example, -// besides fully compatible fields, having more fields or higher priority. -bool IsRefinement(const DateMatch& a, const DateMatch& b); -bool IsRefinement(const DateRangeMatch& a, const DateRangeMatch& b); - -// Returns true iff "a" occurs strictly before "b" -bool IsPrecedent(const DateMatch& a, const DateMatch& b); - -// Fill DatetimeParseResult based on DateMatch object which is created from -// matched rule. 
The matched string is extracted from tokenizer which provides -// an interface to access the clean text based on the matched range. -void FillDateInstance(const DateMatch& date, DatetimeParseResult* instance); - -// Fill DatetimeParseResultSpan based on DateMatch object which is created from -// matched rule. The matched string is extracted from tokenizer which provides -// an interface to access the clean text based on the matched range. -void FillDateInstance(const DateMatch& date, DatetimeParseResultSpan* instance); - -// Fill DatetimeParseResultSpan based on DateRangeMatch object which i screated -// from matched rule. -void FillDateRangeInstance(const DateRangeMatch& range, - DatetimeParseResultSpan* instance); - -// Merge the fields in DateMatch prev to next if there is no overlapped field. -// If update_span is true, the span of next is also updated. -// e.g.: prev is 11am, next is: May 1, then the merged next is May 1, 11am -void MergeDateMatch(const DateMatch& prev, DateMatch* next, bool update_span); - -// If DateMatches have no overlapped field, then they could be merged as the -// following rules: -// -- If both don't have relative match and one DateMatch has day but another -// DateMatch has hour. -// -- If one have relative match then follow the rules in code. -// It's impossible to get DateMatch which only has DOW and not in relative -// match according to current rules. -bool IsDateMatchMergeable(const DateMatch& prev, const DateMatch& next); -} // namespace dates -} // namespace libtextclassifier3 - -#endif // LIBTEXTCLASSIFIER_ANNOTATOR_GRAMMAR_DATES_UTILS_DATE_UTILS_H_ diff --git a/annotator/grammar/grammar-annotator.cc b/annotator/grammar/grammar-annotator.cc index 38f5709..d4df327 100644 --- a/annotator/grammar/grammar-annotator.cc +++ b/annotator/grammar/grammar-annotator.cc
@@ -18,12 +18,8 @@ #include "annotator/feature-processor.h" #include "annotator/grammar/utils.h" #include "annotator/types.h" +#include "utils/base/arena.h" #include "utils/base/logging.h" -#include "utils/grammar/callback-delegate.h" -#include "utils/grammar/match.h" -#include "utils/grammar/matcher.h" -#include "utils/grammar/rules-utils.h" -#include "utils/grammar/types.h" #include "utils/normalization.h" #include "utils/optional.h" #include "utils/utf8/unicodetext.h" @@ -31,447 +27,296 @@ namespace libtextclassifier3 { namespace { -// Returns the unicode codepoint offsets in a utf8 encoded text. -std::vector<UnicodeText::const_iterator> UnicodeCodepointOffsets( - const UnicodeText& text) { - std::vector<UnicodeText::const_iterator> offsets; - for (auto it = text.begin(); it != text.end(); it++) { - offsets.push_back(it); +// Retrieves all capturing nodes from a parse tree. +std::unordered_map<uint16, const grammar::ParseTree*> GetCapturingNodes( + const grammar::ParseTree* parse_tree) { + std::unordered_map<uint16, const grammar::ParseTree*> capturing_nodes; + for (const grammar::MappingNode* mapping_node : + grammar::SelectAllOfType<grammar::MappingNode>( + parse_tree, grammar::ParseTree::Type::kMapping)) { + capturing_nodes[mapping_node->id] = mapping_node; } - offsets.push_back(text.end()); - return offsets; + return capturing_nodes; +} + +// Computes the selection boundaries from a parse tree. +CodepointSpan MatchSelectionBoundaries( + const grammar::ParseTree* parse_tree, + const GrammarModel_::RuleClassificationResult* classification) { + if (classification->capturing_group() == nullptr) { + // Use full match as selection span. + return parse_tree->codepoint_span; + } + + // Set information from capturing matches. + CodepointSpan span{kInvalidIndex, kInvalidIndex}; + std::unordered_map<uint16, const grammar::ParseTree*> capturing_nodes = + GetCapturingNodes(parse_tree); + + // Compute span boundaries. 
+ for (int i = 0; i < classification->capturing_group()->size(); i++) { + auto it = capturing_nodes.find(i); + if (it == capturing_nodes.end()) { + // Capturing group is not active, skip. + continue; + } + const CapturingGroup* group = classification->capturing_group()->Get(i); + if (group->extend_selection()) { + if (span.first == kInvalidIndex) { + span = it->second->codepoint_span; + } else { + span.first = std::min(span.first, it->second->codepoint_span.first); + span.second = std::max(span.second, it->second->codepoint_span.second); + } + } + } + return span; } } // namespace -class GrammarAnnotatorCallbackDelegate : public grammar::CallbackDelegate { - public: - explicit GrammarAnnotatorCallbackDelegate( - const UniLib* unilib, const GrammarModel* model, - const MutableFlatbufferBuilder* entity_data_builder, const ModeFlag mode) - : unilib_(*unilib), - model_(model), - entity_data_builder_(entity_data_builder), - mode_(mode) {} - - // Handles a grammar rule match in the annotator grammar. - void MatchFound(const grammar::Match* match, grammar::CallbackId type, - int64 value, grammar::Matcher* matcher) override { - switch (static_cast<GrammarAnnotator::Callback>(type)) { - case GrammarAnnotator::Callback::kRuleMatch: { - HandleRuleMatch(match, /*rule_id=*/value); - return; - } - default: - grammar::CallbackDelegate::MatchFound(match, type, value, matcher); - } - } - - // Deduplicate and populate annotations from grammar matches. - bool GetAnnotations(const std::vector<UnicodeText::const_iterator>& text, - std::vector<AnnotatedSpan>* annotations) const { - for (const grammar::Derivation& candidate : - grammar::DeduplicateDerivations(candidates_)) { - // Check that assertions are fulfilled. 
- if (!grammar::VerifyAssertions(candidate.match)) { - continue; - } - if (!AddAnnotatedSpanFromMatch(text, candidate, annotations)) { - return false; - } - } - return true; - } - - bool GetTextSelection(const std::vector<UnicodeText::const_iterator>& text, - const CodepointSpan& selection, AnnotatedSpan* result) { - std::vector<grammar::Derivation> selection_candidates; - // Deduplicate and verify matches. - auto maybe_interpretation = GetBestValidInterpretation( - grammar::DeduplicateDerivations(GetOverlappingRuleMatches( - selection, candidates_, /*only_exact_overlap=*/false))); - if (!maybe_interpretation.has_value()) { - return false; - } - const GrammarModel_::RuleClassificationResult* interpretation; - const grammar::Match* match; - std::tie(interpretation, match) = maybe_interpretation.value(); - return InstantiateAnnotatedSpanFromInterpretation(text, interpretation, - match, result); - } - - // Provides a classification results from the grammar matches. - bool GetClassification(const std::vector<UnicodeText::const_iterator>& text, - const CodepointSpan& selection, - ClassificationResult* classification) const { - // Deduplicate and verify matches. - auto maybe_interpretation = GetBestValidInterpretation( - grammar::DeduplicateDerivations(GetOverlappingRuleMatches( - selection, candidates_, /*only_exact_overlap=*/true))); - if (!maybe_interpretation.has_value()) { - return false; - } - - // Instantiate result. - const GrammarModel_::RuleClassificationResult* interpretation; - const grammar::Match* match; - std::tie(interpretation, match) = maybe_interpretation.value(); - return InstantiateClassificationInterpretation(text, interpretation, match, - classification); - } - - private: - // Handles annotation/selection/classification rule matches. 
- void HandleRuleMatch(const grammar::Match* match, const int64 rule_id) { - if ((model_->rule_classification_result()->Get(rule_id)->enabled_modes() & - mode_) != 0) { - candidates_.push_back(grammar::Derivation{match, rule_id}); - } - } - - // Computes the selection boundaries from a grammar match. - CodepointSpan MatchSelectionBoundaries( - const grammar::Match* match, - const GrammarModel_::RuleClassificationResult* classification) const { - if (classification->capturing_group() == nullptr) { - // Use full match as selection span. - return match->codepoint_span; - } - - // Set information from capturing matches. - CodepointSpan span{kInvalidIndex, kInvalidIndex}; - // Gather active capturing matches. - std::unordered_map<uint16, const grammar::Match*> capturing_matches; - for (const grammar::MappingMatch* match : - grammar::SelectAllOfType<grammar::MappingMatch>( - match, grammar::Match::kMappingMatch)) { - capturing_matches[match->id] = match; - } - - // Compute span boundaries. - for (int i = 0; i < classification->capturing_group()->size(); i++) { - auto it = capturing_matches.find(i); - if (it == capturing_matches.end()) { - // Capturing group is not active, skip. - continue; - } - const CapturingGroup* group = classification->capturing_group()->Get(i); - if (group->extend_selection()) { - if (span.first == kInvalidIndex) { - span = it->second->codepoint_span; - } else { - span.first = std::min(span.first, it->second->codepoint_span.first); - span.second = - std::max(span.second, it->second->codepoint_span.second); - } - } - } - return span; - } - - // Filters out results that do not overlap with a reference span. - std::vector<grammar::Derivation> GetOverlappingRuleMatches( - const CodepointSpan& selection, - const std::vector<grammar::Derivation>& candidates, - const bool only_exact_overlap) const { - std::vector<grammar::Derivation> result; - for (const grammar::Derivation& candidate : candidates) { - // Discard matches that do not match the selection. 
- // Simple check. - if (!SpansOverlap(selection, candidate.match->codepoint_span)) { - continue; - } - - // Compute exact selection boundaries (without assertions and - // non-capturing parts). - const CodepointSpan span = MatchSelectionBoundaries( - candidate.match, - model_->rule_classification_result()->Get(candidate.rule_id)); - if (!SpansOverlap(selection, span) || - (only_exact_overlap && span != selection)) { - continue; - } - result.push_back(candidate); - } - return result; - } - - // Returns the best valid interpretation of a set of candidate matches. - Optional<std::pair<const GrammarModel_::RuleClassificationResult*, - const grammar::Match*>> - GetBestValidInterpretation( - const std::vector<grammar::Derivation>& candidates) const { - const GrammarModel_::RuleClassificationResult* best_interpretation = - nullptr; - const grammar::Match* best_match = nullptr; - for (const grammar::Derivation& candidate : candidates) { - if (!grammar::VerifyAssertions(candidate.match)) { - continue; - } - const GrammarModel_::RuleClassificationResult* - rule_classification_result = - model_->rule_classification_result()->Get(candidate.rule_id); - if (best_interpretation == nullptr || - best_interpretation->priority_score() < - rule_classification_result->priority_score()) { - best_interpretation = rule_classification_result; - best_match = candidate.match; - } - } - - // No valid interpretation found. - Optional<std::pair<const GrammarModel_::RuleClassificationResult*, - const grammar::Match*>> - result; - if (best_interpretation != nullptr) { - result = {best_interpretation, best_match}; - } - return result; - } - - // Instantiates an annotated span from a rule match and appends it to the - // result. 
- bool AddAnnotatedSpanFromMatch( - const std::vector<UnicodeText::const_iterator>& text, - const grammar::Derivation& candidate, - std::vector<AnnotatedSpan>* result) const { - if (candidate.rule_id < 0 || - candidate.rule_id >= model_->rule_classification_result()->size()) { - TC3_LOG(INFO) << "Invalid rule id."; - return false; - } - const GrammarModel_::RuleClassificationResult* interpretation = - model_->rule_classification_result()->Get(candidate.rule_id); - result->emplace_back(); - return InstantiateAnnotatedSpanFromInterpretation( - text, interpretation, candidate.match, &result->back()); - } - - bool InstantiateAnnotatedSpanFromInterpretation( - const std::vector<UnicodeText::const_iterator>& text, - const GrammarModel_::RuleClassificationResult* interpretation, - const grammar::Match* match, AnnotatedSpan* result) const { - result->span = MatchSelectionBoundaries(match, interpretation); - ClassificationResult classification; - if (!InstantiateClassificationInterpretation(text, interpretation, match, - &classification)) { - return false; - } - result->classification.push_back(classification); - return true; - } - - // Instantiates a classification result from a rule match. - bool InstantiateClassificationInterpretation( - const std::vector<UnicodeText::const_iterator>& text, - const GrammarModel_::RuleClassificationResult* interpretation, - const grammar::Match* match, ClassificationResult* classification) const { - classification->collection = interpretation->collection_name()->str(); - classification->score = interpretation->target_classification_score(); - classification->priority_score = interpretation->priority_score(); - - // Assemble entity data. 
- if (entity_data_builder_ == nullptr) { - return true; - } - std::unique_ptr<MutableFlatbuffer> entity_data = - entity_data_builder_->NewRoot(); - if (interpretation->serialized_entity_data() != nullptr) { - entity_data->MergeFromSerializedFlatbuffer( - StringPiece(interpretation->serialized_entity_data()->data(), - interpretation->serialized_entity_data()->size())); - } - if (interpretation->entity_data() != nullptr) { - entity_data->MergeFrom(reinterpret_cast<const flatbuffers::Table*>( - interpretation->entity_data())); - } - - // Populate entity data from the capturing matches. - if (interpretation->capturing_group() != nullptr) { - // Gather active capturing matches. - std::unordered_map<uint16, const grammar::Match*> capturing_matches; - for (const grammar::MappingMatch* match : - grammar::SelectAllOfType<grammar::MappingMatch>( - match, grammar::Match::kMappingMatch)) { - capturing_matches[match->id] = match; - } - for (int i = 0; i < interpretation->capturing_group()->size(); i++) { - auto it = capturing_matches.find(i); - if (it == capturing_matches.end()) { - // Capturing group is not active, skip. - continue; - } - const CapturingGroup* group = interpretation->capturing_group()->Get(i); - - // Add static entity data. - if (group->serialized_entity_data() != nullptr) { - entity_data->MergeFromSerializedFlatbuffer( - StringPiece(interpretation->serialized_entity_data()->data(), - interpretation->serialized_entity_data()->size())); - } - - // Set entity field from captured text. 
- if (group->entity_field_path() != nullptr) { - const grammar::Match* capturing_match = it->second; - StringPiece group_text = StringPiece( - text[capturing_match->codepoint_span.first].utf8_data(), - text[capturing_match->codepoint_span.second].utf8_data() - - text[capturing_match->codepoint_span.first].utf8_data()); - UnicodeText normalized_group_text = - UTF8ToUnicodeText(group_text, /*do_copy=*/false); - if (group->normalization_options() != nullptr) { - normalized_group_text = NormalizeText( - unilib_, group->normalization_options(), normalized_group_text); - } - if (!entity_data->ParseAndSet(group->entity_field_path(), - normalized_group_text.ToUTF8String())) { - TC3_LOG(ERROR) << "Could not set entity data from capturing match."; - return false; - } - } - } - } - - if (entity_data && entity_data->HasExplicitlySetFields()) { - classification->serialized_entity_data = entity_data->Serialize(); - } - return true; - } - - const UniLib& unilib_; - const GrammarModel* model_; - const MutableFlatbufferBuilder* entity_data_builder_; - const ModeFlag mode_; - - // All annotation/selection/classification rule match candidates. - // Grammar rule matches are recorded, deduplicated and then instantiated. - std::vector<grammar::Derivation> candidates_; -}; - GrammarAnnotator::GrammarAnnotator( const UniLib* unilib, const GrammarModel* model, const MutableFlatbufferBuilder* entity_data_builder) : unilib_(*unilib), model_(model), - lexer_(unilib, model->rules()), tokenizer_(BuildTokenizer(unilib, model->tokenizer_options())), entity_data_builder_(entity_data_builder), - rules_locales_(grammar::ParseRulesLocales(model->rules())) {} + analyzer_(unilib, model->rules(), &tokenizer_) {} + +// Filters out results that do not overlap with a reference span. 
+std::vector<grammar::Derivation> GrammarAnnotator::OverlappingDerivations( + const CodepointSpan& selection, + const std::vector<grammar::Derivation>& derivations, + const bool only_exact_overlap) const { + std::vector<grammar::Derivation> result; + for (const grammar::Derivation& derivation : derivations) { + // Discard matches that do not match the selection. + // Simple check. + if (!SpansOverlap(selection, derivation.parse_tree->codepoint_span)) { + continue; + } + + // Compute exact selection boundaries (without assertions and + // non-capturing parts). + const CodepointSpan span = MatchSelectionBoundaries( + derivation.parse_tree, + model_->rule_classification_result()->Get(derivation.rule_id)); + if (!SpansOverlap(selection, span) || + (only_exact_overlap && span != selection)) { + continue; + } + result.push_back(derivation); + } + return result; +} + +bool GrammarAnnotator::InstantiateAnnotatedSpanFromDerivation( + const grammar::TextContext& input_context, + const grammar::ParseTree* parse_tree, + const GrammarModel_::RuleClassificationResult* interpretation, + AnnotatedSpan* result) const { + result->span = MatchSelectionBoundaries(parse_tree, interpretation); + ClassificationResult classification; + if (!InstantiateClassificationFromDerivation( + input_context, parse_tree, interpretation, &classification)) { + return false; + } + result->classification.push_back(classification); + return true; +} + +// Instantiates a classification result from a rule match. 
+bool GrammarAnnotator::InstantiateClassificationFromDerivation( + const grammar::TextContext& input_context, + const grammar::ParseTree* parse_tree, + const GrammarModel_::RuleClassificationResult* interpretation, + ClassificationResult* classification) const { + classification->collection = interpretation->collection_name()->str(); + classification->score = interpretation->target_classification_score(); + classification->priority_score = interpretation->priority_score(); + + // Assemble entity data. + if (entity_data_builder_ == nullptr) { + return true; + } + std::unique_ptr<MutableFlatbuffer> entity_data = + entity_data_builder_->NewRoot(); + if (interpretation->serialized_entity_data() != nullptr) { + entity_data->MergeFromSerializedFlatbuffer( + StringPiece(interpretation->serialized_entity_data()->data(), + interpretation->serialized_entity_data()->size())); + } + if (interpretation->entity_data() != nullptr) { + entity_data->MergeFrom(reinterpret_cast<const flatbuffers::Table*>( + interpretation->entity_data())); + } + + // Populate entity data from the capturing matches. + if (interpretation->capturing_group() != nullptr) { + // Gather active capturing matches. + std::unordered_map<uint16, const grammar::ParseTree*> capturing_nodes = + GetCapturingNodes(parse_tree); + + for (int i = 0; i < interpretation->capturing_group()->size(); i++) { + auto it = capturing_nodes.find(i); + if (it == capturing_nodes.end()) { + // Capturing group is not active, skip. + continue; + } + const CapturingGroup* group = interpretation->capturing_group()->Get(i); + + // Add static entity data. + if (group->serialized_entity_data() != nullptr) { + entity_data->MergeFromSerializedFlatbuffer( + StringPiece(interpretation->serialized_entity_data()->data(), + interpretation->serialized_entity_data()->size())); + } + + // Set entity field from captured text. 
+ if (group->entity_field_path() != nullptr) { + const grammar::ParseTree* capturing_match = it->second; + UnicodeText match_text = + input_context.Span(capturing_match->codepoint_span); + if (group->normalization_options() != nullptr) { + match_text = NormalizeText(unilib_, group->normalization_options(), + match_text); + } + if (!entity_data->ParseAndSet(group->entity_field_path(), + match_text.ToUTF8String())) { + TC3_LOG(ERROR) << "Could not set entity data from capturing match."; + return false; + } + } + } + } + + if (entity_data && entity_data->HasExplicitlySetFields()) { + classification->serialized_entity_data = entity_data->Serialize(); + } + return true; +} bool GrammarAnnotator::Annotate(const std::vector<Locale>& locales, const UnicodeText& text, std::vector<AnnotatedSpan>* result) const { - if (model_ == nullptr || model_->rules() == nullptr) { - // Nothing to do. - return true; + grammar::TextContext input_context = + analyzer_.BuildTextContextForInput(text, locales); + + UnsafeArena arena(/*block_size=*/16 << 10); + + for (const grammar::Derivation& derivation : ValidDeduplicatedDerivations( + analyzer_.parser().Parse(input_context, &arena))) { + const GrammarModel_::RuleClassificationResult* interpretation = + model_->rule_classification_result()->Get(derivation.rule_id); + if ((interpretation->enabled_modes() & ModeFlag_ANNOTATION) == 0) { + continue; + } + result->emplace_back(); + if (!InstantiateAnnotatedSpanFromDerivation( + input_context, derivation.parse_tree, interpretation, + &result->back())) { + return false; + } } - // Select locale matching rules. - std::vector<const grammar::RulesSet_::Rules*> locale_rules = - SelectLocaleMatchingShards(model_->rules(), rules_locales_, locales); - if (locale_rules.empty()) { - // Nothing to do. - return true; - } - - // Run the grammar. 
- GrammarAnnotatorCallbackDelegate callback_handler( - &unilib_, model_, entity_data_builder_, - /*mode=*/ModeFlag_ANNOTATION); - grammar::Matcher matcher(&unilib_, model_->rules(), locale_rules, - &callback_handler); - lexer_.Process(text, tokenizer_.Tokenize(text), /*annotations=*/nullptr, - &matcher); - - // Populate results. - return callback_handler.GetAnnotations(UnicodeCodepointOffsets(text), result); + return true; } bool GrammarAnnotator::SuggestSelection(const std::vector<Locale>& locales, const UnicodeText& text, const CodepointSpan& selection, AnnotatedSpan* result) const { - if (model_ == nullptr || model_->rules() == nullptr || - selection == CodepointSpan{kInvalidIndex, kInvalidIndex}) { - // Nothing to do. + if (!selection.IsValid() || selection.IsEmpty()) { return false; } - // Select locale matching rules. - std::vector<const grammar::RulesSet_::Rules*> locale_rules = - SelectLocaleMatchingShards(model_->rules(), rules_locales_, locales); - if (locale_rules.empty()) { - // Nothing to do. - return true; + grammar::TextContext input_context = + analyzer_.BuildTextContextForInput(text, locales); + + UnsafeArena arena(/*block_size=*/16 << 10); + + const GrammarModel_::RuleClassificationResult* best_interpretation = nullptr; + const grammar::ParseTree* best_match = nullptr; + for (const grammar::Derivation& derivation : + ValidDeduplicatedDerivations(OverlappingDerivations( + selection, analyzer_.parser().Parse(input_context, &arena), + /*only_exact_overlap=*/false))) { + const GrammarModel_::RuleClassificationResult* interpretation = + model_->rule_classification_result()->Get(derivation.rule_id); + if ((interpretation->enabled_modes() & ModeFlag_SELECTION) == 0) { + continue; + } + if (best_interpretation == nullptr || + interpretation->priority_score() > + best_interpretation->priority_score()) { + best_interpretation = interpretation; + best_match = derivation.parse_tree; + } } - // Run the grammar. 
- GrammarAnnotatorCallbackDelegate callback_handler( - &unilib_, model_, entity_data_builder_, - /*mode=*/ModeFlag_SELECTION); - grammar::Matcher matcher(&unilib_, model_->rules(), locale_rules, - &callback_handler); - lexer_.Process(text, tokenizer_.Tokenize(text), /*annotations=*/nullptr, - &matcher); + if (best_interpretation == nullptr) { + return false; + } - // Populate the result. - return callback_handler.GetTextSelection(UnicodeCodepointOffsets(text), - selection, result); + return InstantiateAnnotatedSpanFromDerivation(input_context, best_match, + best_interpretation, result); } bool GrammarAnnotator::ClassifyText( const std::vector<Locale>& locales, const UnicodeText& text, const CodepointSpan& selection, ClassificationResult* classification_result) const { - if (model_ == nullptr || model_->rules() == nullptr || - selection == CodepointSpan{kInvalidIndex, kInvalidIndex}) { + if (!selection.IsValid() || selection.IsEmpty()) { // Nothing to do. return false; } - // Select locale matching rules. - std::vector<const grammar::RulesSet_::Rules*> locale_rules = - SelectLocaleMatchingShards(model_->rules(), rules_locales_, locales); - if (locale_rules.empty()) { - // Nothing to do. 
+ grammar::TextContext input_context = + analyzer_.BuildTextContextForInput(text, locales); + + if (const TokenSpan context_span = CodepointSpanToTokenSpan( + input_context.tokens, selection, + /*snap_boundaries_to_containing_tokens=*/true); + context_span.IsValid()) { + if (model_->context_left_num_tokens() != kInvalidIndex) { + input_context.context_span.first = + std::max(0, context_span.first - model_->context_left_num_tokens()); + } + if (model_->context_right_num_tokens() != kInvalidIndex) { + input_context.context_span.second = + std::min(static_cast<int>(input_context.tokens.size()), + context_span.second + model_->context_right_num_tokens()); + } + } + + UnsafeArena arena(/*block_size=*/16 << 10); + + const GrammarModel_::RuleClassificationResult* best_interpretation = nullptr; + const grammar::ParseTree* best_match = nullptr; + for (const grammar::Derivation& derivation : + ValidDeduplicatedDerivations(OverlappingDerivations( + selection, analyzer_.parser().Parse(input_context, &arena), + /*only_exact_overlap=*/true))) { + const GrammarModel_::RuleClassificationResult* interpretation = + model_->rule_classification_result()->Get(derivation.rule_id); + if ((interpretation->enabled_modes() & ModeFlag_CLASSIFICATION) == 0) { + continue; + } + if (best_interpretation == nullptr || + interpretation->priority_score() > + best_interpretation->priority_score()) { + best_interpretation = interpretation; + best_match = derivation.parse_tree; + } + } + + if (best_interpretation == nullptr) { return false; } - // Run the grammar. - GrammarAnnotatorCallbackDelegate callback_handler( - &unilib_, model_, entity_data_builder_, - /*mode=*/ModeFlag_CLASSIFICATION); - grammar::Matcher matcher(&unilib_, model_->rules(), locale_rules, - &callback_handler); - - const std::vector<Token> tokens = tokenizer_.Tokenize(text); - if (model_->context_left_num_tokens() == -1 && - model_->context_right_num_tokens() == -1) { - // Use all tokens. 
- lexer_.Process(text, tokens, /*annotations=*/{}, &matcher); - } else { - TokenSpan context_span = CodepointSpanToTokenSpan( - tokens, selection, /*snap_boundaries_to_containing_tokens=*/true); - std::vector<Token>::const_iterator begin = tokens.begin(); - std::vector<Token>::const_iterator end = tokens.begin(); - if (model_->context_left_num_tokens() != -1) { - std::advance(begin, std::max(0, context_span.first - - model_->context_left_num_tokens())); - } - if (model_->context_right_num_tokens() == -1) { - end = tokens.end(); - } else { - std::advance(end, std::min(static_cast<int>(tokens.size()), - context_span.second + - model_->context_right_num_tokens())); - } - lexer_.Process(text, begin, end, - /*annotations=*/nullptr, &matcher); - } - - // Populate result. - return callback_handler.GetClassification(UnicodeCodepointOffsets(text), - selection, classification_result); + return InstantiateClassificationFromDerivation( + input_context, best_match, best_interpretation, classification_result); } } // namespace libtextclassifier3 diff --git a/annotator/grammar/grammar-annotator.h b/annotator/grammar/grammar-annotator.h index 2ac6988..08b3040 100644 --- a/annotator/grammar/grammar-annotator.h +++ b/annotator/grammar/grammar-annotator.h
@@ -21,7 +21,9 @@ #include "annotator/model_generated.h" #include "annotator/types.h" #include "utils/flatbuffers/mutable.h" -#include "utils/grammar/lexer.h" +#include "utils/grammar/analyzer.h" +#include "utils/grammar/evaluated-derivation.h" +#include "utils/grammar/text-context.h" #include "utils/i18n/locale.h" #include "utils/tokenizer.h" #include "utils/utf8/unicodetext.h" @@ -32,10 +34,6 @@ // Grammar backed annotator. class GrammarAnnotator { public: - enum class Callback : grammar::CallbackId { - kRuleMatch = 1, - }; - explicit GrammarAnnotator( const UniLib* unilib, const GrammarModel* model, const MutableFlatbufferBuilder* entity_data_builder); @@ -58,14 +56,31 @@ AnnotatedSpan* result) const; private: + // Filters out derivations that do not overlap with a reference span. + std::vector<grammar::Derivation> OverlappingDerivations( + const CodepointSpan& selection, + const std::vector<grammar::Derivation>& derivations, + const bool only_exact_overlap) const; + + // Fills out an annotated span from a grammar match result. + bool InstantiateAnnotatedSpanFromDerivation( + const grammar::TextContext& input_context, + const grammar::ParseTree* parse_tree, + const GrammarModel_::RuleClassificationResult* interpretation, + AnnotatedSpan* result) const; + + // Instantiates a classification result from a rule match. + bool InstantiateClassificationFromDerivation( + const grammar::TextContext& input_context, + const grammar::ParseTree* parse_tree, + const GrammarModel_::RuleClassificationResult* interpretation, + ClassificationResult* classification) const; + const UniLib& unilib_; const GrammarModel* model_; - const grammar::Lexer lexer_; const Tokenizer tokenizer_; const MutableFlatbufferBuilder* entity_data_builder_; - - // Pre-parsed locales of the rules. 
- const std::vector<std::vector<Locale>> rules_locales_; + const grammar::Analyzer analyzer_; }; } // namespace libtextclassifier3 diff --git a/annotator/knowledge/knowledge-engine-dummy.h b/annotator/knowledge/knowledge-engine-dummy.h index ecd687a..0320c53 100644 --- a/annotator/knowledge/knowledge-engine-dummy.h +++ b/annotator/knowledge/knowledge-engine-dummy.h
@@ -51,12 +51,13 @@ return true; } - Status ChunkMultipleSpans(const std::vector<std::string>& text_fragments, - AnnotationUsecase annotation_usecase, - const Optional<LocationContext>& location_context, - const Permissions& permissions, - const AnnotateMode annotate_mode, - Annotations* results) const { + Status ChunkMultipleSpans( + const std::vector<std::string>& text_fragments, + const std::vector<FragmentMetadata>& fragment_metadata, + AnnotationUsecase annotation_usecase, + const Optional<LocationContext>& location_context, + const Permissions& permissions, const AnnotateMode annotate_mode, + Annotations* results) const { return Status::OK; } diff --git a/annotator/knowledge/knowledge-engine-types.h b/annotator/knowledge/knowledge-engine-types.h index 6757a2e..2d30f8c 100644 --- a/annotator/knowledge/knowledge-engine-types.h +++ b/annotator/knowledge/knowledge-engine-types.h
@@ -20,6 +20,11 @@ enum AnnotateMode { kEntityAnnotation, kEntityAndTopicalityAnnotation }; +struct FragmentMetadata { + float relative_bounding_box_top; + float relative_bounding_box_height; +}; + } // namespace libtextclassifier3 #endif // LIBTEXTCLASSIFIER_ANNOTATOR_KNOWLEDGE_KNOWLEDGE_ENGINE_TYPES_H_ diff --git a/annotator/model.fbs b/annotator/model.fbs index 64e0911..60c3688 100755 --- a/annotator/model.fbs +++ b/annotator/model.fbs
@@ -13,18 +13,17 @@ // limitations under the License. // -include "utils/flatbuffers/flatbuffers.fbs"; include "utils/container/bit-vector.fbs"; -include "annotator/experimental/experimental.fbs"; -include "utils/grammar/rules.fbs"; -include "utils/tokenizer.fbs"; -include "annotator/entity-data.fbs"; -include "utils/normalization.fbs"; -include "utils/zlib/buffer.fbs"; -include "utils/resources.fbs"; include "utils/intents/intent-config.fbs"; +include "utils/normalization.fbs"; +include "utils/flatbuffers/flatbuffers.fbs"; +include "annotator/experimental/experimental.fbs"; +include "utils/resources.fbs"; +include "annotator/entity-data.fbs"; include "utils/codepoint-range.fbs"; -include "annotator/grammar/dates/dates.fbs"; +include "utils/tokenizer.fbs"; +include "utils/zlib/buffer.fbs"; +include "utils/grammar/rules.fbs"; file_identifier "TC2 "; @@ -373,85 +372,6 @@ tokenize_on_script_change:bool = false; } -// Options for grammar date/datetime/date range annotations. -namespace libtextclassifier3.GrammarDatetimeModel_; -table AnnotationOptions { - // If enabled, extract special day offset like today, yesterday, etc. - enable_special_day_offset:bool = true; - - // If true, merge the adjacent day of week, time and date. e.g. - // "20/2/2016 at 8pm" is extracted as a single instance instead of two - // instance: "20/2/2016" and "8pm". - merge_adjacent_components:bool = true; - - // List the extra id of requested dates. - extra_requested_dates:[string]; - - // If true, try to include preposition to the extracted annotation. e.g. - // "at 6pm". if it's false, only 6pm is included. offline-actions has - // special requirements to include preposition. - include_preposition:bool = true; - - // If enabled, extract range in date annotator. - // input: Monday, 5-6pm - // If the flag is true, The extracted annotation only contains 1 range - // instance which is from Monday 5pm to 6pm. 
- // If the flag is false, The extracted annotation contains two date - // instance: "Monday" and "6pm". - enable_date_range:bool = true; - reserved_6:int16 (deprecated); - - // If enabled, the rule priority score is used to set the priority score of - // the annotation. - // In case of false the annotation priority score is set from - // GrammarDatetimeModel's priority_score - use_rule_priority_score:bool = false; - - // If enabled, annotator will try to resolve the ambiguity by generating - // possible alternative interpretations of the input text - // e.g. '9:45' will be resolved to '9:45 AM' and '9:45 PM'. - generate_alternative_interpretations_when_ambiguous:bool; - - // List of spans which grammar will ignore during the match e.g. if - // “@” is in the allowed span list and input is “12 March @ 12PM” then “@” - // will be ignored and 12 March @ 12PM will be translate to - // {Day:12 Month: March Hour: 12 MERIDIAN: PM}. - // This can also be achieved by adding additional rules e.g. - // <Digit_Day> <Month> <Time> - // <Digit_Day> <Month> @ <Time> - // Though this is doable in the grammar but requires multiple rules, this - // list enables the rule to represent multiple rules. - ignored_spans:[string]; -} - -namespace libtextclassifier3; -table GrammarDatetimeModel { - // List of BCP 47 locale strings representing all locales supported by the - // model. - locales:[string]; - - // If true, will give only future dates (when the day is not specified). - prefer_future_for_unspecified_date:bool = false; - - // Grammar specific tokenizer options. - grammar_tokenizer_options:GrammarTokenizerOptions; - - // The modes for which to apply the grammars. - enabled_modes:ModeFlag = ALL; - - // The datetime grammar rules. - datetime_rules:dates.DatetimeRules; - - // The final score to assign to the results of grammar model - target_classification_score:float = 1; - - // The priority score used for conflict resolution with the other models. 
- priority_score:float = 0; - - // Options for grammar annotations. - annotation_options:GrammarDatetimeModel_.AnnotationOptions; -} - namespace libtextclassifier3.DatetimeModelLibrary_; table Item { key:string; @@ -666,7 +586,7 @@ triggering_locales:string; embedding_pruning_mask:Model_.EmbeddingPruningMask; - grammar_datetime_model:GrammarDatetimeModel; + reserved_25:int16 (deprecated); contact_annotator_options:ContactAnnotatorOptions; money_parsing_options:MoneyParsingOptions; translate_annotator_options:TranslateAnnotatorOptions; diff --git a/annotator/pod_ner/pod-ner-dummy.h b/annotator/pod_ner/pod-ner-dummy.h index 8d90529..2c13dd0 100644 --- a/annotator/pod_ner/pod-ner-dummy.h +++ b/annotator/pod_ner/pod-ner-dummy.h
@@ -39,8 +39,8 @@ return true; } - AnnotatedSpan SuggestSelection(const UnicodeText &context, - CodepointSpan click) const { + bool SuggestSelection(const UnicodeText &context, CodepointSpan click, + AnnotatedSpan *result) const { return {}; } diff --git a/annotator/strip-unpaired-brackets.cc b/annotator/strip-unpaired-brackets.cc index c1c257d..b72db68 100644 --- a/annotator/strip-unpaired-brackets.cc +++ b/annotator/strip-unpaired-brackets.cc
@@ -21,23 +21,59 @@ #include "utils/utf8/unicodetext.h" namespace libtextclassifier3 { +namespace { -CodepointSpan StripUnpairedBrackets( - const UnicodeText::const_iterator& span_begin, - const UnicodeText::const_iterator& span_end, CodepointSpan span, - const UniLib& unilib) { - if (span_begin == span_end || !span.IsValid() || span.IsEmpty()) { +// Returns true if given codepoint is contained in the given span in context. +bool IsCodepointInSpan(const char32 codepoint, + const UnicodeText& context_unicode, + const CodepointSpan span) { + auto begin_it = context_unicode.begin(); + std::advance(begin_it, span.first); + auto end_it = context_unicode.begin(); + std::advance(end_it, span.second); + + return std::find(begin_it, end_it, codepoint) != end_it; +} + +// Returns the first codepoint of the span. +char32 FirstSpanCodepoint(const UnicodeText& context_unicode, + const CodepointSpan span) { + auto it = context_unicode.begin(); + std::advance(it, span.first); + return *it; +} + +// Returns the last codepoint of the span. +char32 LastSpanCodepoint(const UnicodeText& context_unicode, + const CodepointSpan span) { + auto it = context_unicode.begin(); + std::advance(it, span.second - 1); + return *it; +} + +} // namespace + +CodepointSpan StripUnpairedBrackets(const std::string& context, + CodepointSpan span, const UniLib& unilib) { + const UnicodeText context_unicode = + UTF8ToUnicodeText(context, /*do_copy=*/false); + return StripUnpairedBrackets(context_unicode, span, unilib); +} + +// If the first or the last codepoint of the given span is a bracket, the +// bracket is stripped if the span does not contain its corresponding paired +// version. 
+CodepointSpan StripUnpairedBrackets(const UnicodeText& context_unicode, + CodepointSpan span, const UniLib& unilib) { + if (context_unicode.empty() || !span.IsValid() || span.IsEmpty()) { return span; } - UnicodeText::const_iterator begin = span_begin; - const UnicodeText::const_iterator end = span_end; - const char32 begin_char = *begin; + const char32 begin_char = FirstSpanCodepoint(context_unicode, span); const char32 paired_begin_char = unilib.GetPairedBracket(begin_char); if (paired_begin_char != begin_char) { if (!unilib.IsOpeningBracket(begin_char) || - std::find(begin, end, paired_begin_char) == end) { - ++begin; + !IsCodepointInSpan(paired_begin_char, context_unicode, span)) { ++span.first; } } @@ -46,11 +82,11 @@ return span; } - const char32 end_char = *std::prev(end); + const char32 end_char = LastSpanCodepoint(context_unicode, span); const char32 paired_end_char = unilib.GetPairedBracket(end_char); if (paired_end_char != end_char) { if (!unilib.IsClosingBracket(end_char) || - std::find(begin, end, paired_end_char) == end) { + !IsCodepointInSpan(paired_end_char, context_unicode, span)) { --span.second; } } @@ -65,21 +101,4 @@ return span; } -CodepointSpan StripUnpairedBrackets(const UnicodeText& context, - CodepointSpan span, const UniLib& unilib) { - if (!span.IsValid() || span.IsEmpty()) { - return span; - } - const UnicodeText span_text = UnicodeText::Substring( - context, span.first, span.second, /*do_copy=*/false); - return StripUnpairedBrackets(span_text.begin(), span_text.end(), span, - unilib); -} - -CodepointSpan StripUnpairedBrackets(const std::string& context, - CodepointSpan span, const UniLib& unilib) { - return StripUnpairedBrackets(UTF8ToUnicodeText(context, /*do_copy=*/false), - span, unilib); -} - } // namespace libtextclassifier3 diff --git a/annotator/strip-unpaired-brackets.h b/annotator/strip-unpaired-brackets.h index 6109a39..19e9819 100644 --- a/annotator/strip-unpaired-brackets.h +++ b/annotator/strip-unpaired-brackets.h
@@ -22,21 +22,14 @@
 #include "utils/utf8/unilib.h"
 
 namespace libtextclassifier3 {
-
 // If the first or the last codepoint of the given span is a bracket, the
 // bracket is stripped if the span does not contain its corresponding paired
 // version.
-CodepointSpan StripUnpairedBrackets(
-    const UnicodeText::const_iterator& span_begin,
-    const UnicodeText::const_iterator& span_end, CodepointSpan span,
-    const UniLib& unilib);
-
-// Same as above but takes a UnicodeText instance for the span.
-CodepointSpan StripUnpairedBrackets(const UnicodeText& context,
+CodepointSpan StripUnpairedBrackets(const std::string& context,
                                     CodepointSpan span, const UniLib& unilib);
 
-// Same as above but takes a string instance.
-CodepointSpan StripUnpairedBrackets(const std::string& context,
+// Same as above but takes a UnicodeText instance directly.
+CodepointSpan StripUnpairedBrackets(const UnicodeText& context_unicode,
                                     CodepointSpan span, const UniLib& unilib);
 
 }  // namespace libtextclassifier3
diff --git a/annotator/translate/translate.cc b/annotator/translate/translate.cc
index 054ead0..893a911 100644
--- a/annotator/translate/translate.cc
+++ b/annotator/translate/translate.cc
@@ -13,8 +13,6 @@ // limitations under the License. // -#pragma GCC diagnostic ignored "-Wc++17-extensions" - #include "annotator/translate/translate.h" #include <memory>
diff --git a/annotator/types.h b/annotator/types.h index a826504..85daa22 100644 --- a/annotator/types.h +++ b/annotator/types.h
@@ -82,7 +82,8 @@ } bool IsValid() const { - return this->first != kInvalidIndex && this->second != kInvalidIndex; + return this->first != kInvalidIndex && this->second != kInvalidIndex && + this->first <= this->second && this->first >= 0; } bool IsEmpty() const { return this->first == this->second; } @@ -281,9 +282,9 @@ SECOND = 8, // Meridiem field where 0 == AM, 1 == PM. MERIDIEM = 9, - // Number of hours offset from UTC this date time is in. + // Offset in number of minutes from UTC this date time is in. ZONE_OFFSET = 10, - // Number of hours offest for DST. + // Offset in number of hours for DST. DST_OFFSET = 11, }; @@ -429,7 +430,8 @@ std::string serialized_knowledge_result; ContactPointer contact_pointer; std::string contact_name, contact_given_name, contact_family_name, - contact_nickname, contact_email_address, contact_phone_number, contact_id; + contact_nickname, contact_email_address, contact_phone_number, + contact_account_type, contact_account_name, contact_id; std::string app_name, app_package_name; int64 numeric_value; double numeric_double_value; @@ -525,7 +527,7 @@ // If true and the model file supports that, the new vocab annotator is used // to annotate "Dictionary". Otherwise, we use the FFModel to do so. - bool use_vocab_annotator = false; + bool use_vocab_annotator = true; bool operator==(const BaseOptions& other) const { bool location_context_equality = this->location_context.has_value() == @@ -682,6 +684,8 @@ struct InputFragment { std::string text; + float bounding_box_top; + float bounding_box_height; // If present will override the AnnotationOptions reference time and timezone // when annotating this specific string fragment. diff --git a/lang_id/common/file/mmap.cc b/lang_id/common/file/mmap.cc index 0bfbea8..9835d2b 100644 --- a/lang_id/common/file/mmap.cc +++ b/lang_id/common/file/mmap.cc
@@ -159,6 +159,7 @@ SAFTM_LOG(ERROR) << "Error closing file descriptor: " << last_error; } } + private: const int fd_; @@ -198,12 +199,19 @@ } MmapHandle MmapFile(int fd, size_t offset_in_bytes, size_t size_in_bytes) { + // Make sure the offset is a multiple of the page size, as returned by + // sysconf(_SC_PAGE_SIZE); this is required by the man-page for mmap. + static const size_t kPageSize = sysconf(_SC_PAGE_SIZE); + const size_t aligned_offset = (offset_in_bytes / kPageSize) * kPageSize; + const size_t alignment_shift = offset_in_bytes - aligned_offset; + const size_t aligned_length = size_in_bytes + alignment_shift; + void *mmap_addr = mmap( // Let system pick address for mmapp-ed data. nullptr, - size_in_bytes, + aligned_length, // One can read / write the mapped data (but see MAP_PRIVATE below). // Normally, we expect only to read it, but in the future, we may want to @@ -217,14 +225,15 @@ // Descriptor of file to mmap. fd, - offset_in_bytes); + aligned_offset); if (mmap_addr == MAP_FAILED) { const std::string last_error = GetLastSystemError(); SAFTM_LOG(ERROR) << "Error while mmapping: " << last_error; return GetErrorMmapHandle(); } - return MmapHandle(mmap_addr, size_in_bytes); + return MmapHandle(static_cast<char *>(mmap_addr) + alignment_shift, + size_in_bytes); } bool Unmap(MmapHandle mmap_handle) { diff --git a/lang_id/script/tiny-script-detector.h b/lang_id/script/tiny-script-detector.h index 8eac366..5a9de5f 100644 --- a/lang_id/script/tiny-script-detector.h +++ b/lang_id/script/tiny-script-detector.h
@@ -73,12 +73,12 @@ // CPU, so it's better to use than int32. static const unsigned int kGreekStart = 0x370; - // Commented out (unsued in the code): kGreekEnd = 0x3FF; + // Commented out (unused in the code): kGreekEnd = 0x3FF; static const unsigned int kCyrillicStart = 0x400; static const unsigned int kCyrillicEnd = 0x4FF; static const unsigned int kHebrewStart = 0x590; - // Commented out (unsued in the code): kHebrewEnd = 0x5FF; + // Commented out (unused in the code): kHebrewEnd = 0x5FF; static const unsigned int kArabicStart = 0x600; static const unsigned int kArabicEnd = 0x6FF; const unsigned int codepoint = ((p[0] & 0x1F) << 6) | (p[1] & 0x3F); @@ -116,7 +116,7 @@ static const unsigned int kHiraganaStart = 0x3041; static const unsigned int kHiraganaEnd = 0x309F; - // Commented out (unsued in the code): kKatakanaStart = 0x30A0; + // Commented out (unused in the code): kKatakanaStart = 0x30A0; static const unsigned int kKatakanaEnd = 0x30FF; const unsigned int codepoint = ((p[0] & 0x0F) << 12) | ((p[1] & 0x3F) << 6) | (p[2] & 0x3F);
diff --git a/utils/base/statusor.h b/utils/base/statusor.h index 8af3d75..afc9389 100644 --- a/utils/base/statusor.h +++ b/utils/base/statusor.h
@@ -17,36 +17,6 @@ #define LIBTEXTCLASSIFIER_UTILS_BASE_STATUSOR_H_ #include <type_traits> - -namespace cros_tclib { - -// C++14 implementation of C++17's std::bool_constant. -// Copied from chromium/src/base/template_util.h -template <bool B> -using bool_constant = std::integral_constant<bool, B>; -// C++14 implementation of C++17's std::conjunction. -// Copied from chromium/src/base/template_util.h -template <typename...> - struct conjunction : std::true_type {}; - -template <typename B1> - struct conjunction<B1> : B1 {}; - -template <typename B1, typename... Bn> - struct conjunction<B1, Bn...> - : std::conditional_t<static_cast<bool>(B1::value), conjunction<Bn...>, B1> { -}; - -// C++14 implementation of C++17's std::negation. -// Copied from chromium/src/base/template_util.h -template <typename B> - struct negation : bool_constant<!static_cast<bool>(B::value)> {}; - -} // namespace cros_tclib - - - -#include <type_traits> #include <utility> #include "utils/base/logging.h" @@ -77,7 +47,7 @@ // Conversion copy constructor, T must be copy constructible from U. template <typename U, std::enable_if_t< - cros_tclib::conjunction<cros_tclib::negation<std::is_same<T, U>>, + std::conjunction<std::negation<std::is_same<T, U>>, std::is_constructible<T, const U&>, std::is_convertible<const U&, T>>::value, int> = 0> @@ -86,7 +56,7 @@ // Conversion move constructor, T must by move constructible from U. template < typename U, - std::enable_if_t<cros_tclib::conjunction<cros_tclib::negation<std::is_same<T, U>>, + std::enable_if_t<std::conjunction<std::negation<std::is_same<T, U>>, std::is_constructible<T, U&&>, std::is_convertible<U&&, T>>::value, int> = 0> @@ -95,7 +65,7 @@ // Value conversion copy constructor, T must by copy constructible from U. 
template <typename U, std::enable_if_t< - cros_tclib::conjunction<cros_tclib::negation<std::is_same<T, U>>, + std::conjunction<std::negation<std::is_same<T, U>>, std::is_constructible<T, const U&>, std::is_convertible<const U&, T>>::value, int> = 0> @@ -104,7 +74,7 @@ // Value conversion move constructor, T must by move constructible from U. template < typename U, - std::enable_if_t<cros_tclib::conjunction<cros_tclib::negation<std::is_same<T, U>>, + std::enable_if_t<std::conjunction<std::negation<std::is_same<T, U>>, std::is_constructible<T, U&&>, std::is_convertible<U&&, T>>::value, int> = 0> @@ -242,7 +212,7 @@ template <typename T> template < typename U, - std::enable_if_t<cros_tclib::conjunction<cros_tclib::negation<std::is_same<T, U>>, + std::enable_if_t<std::conjunction<std::negation<std::is_same<T, U>>, std::is_constructible<T, const U&>, std::is_convertible<const U&, T>>::value, int>> @@ -251,7 +221,7 @@ template <typename T> template <typename U, - std::enable_if_t<cros_tclib::conjunction<cros_tclib::negation<std::is_same<T, U>>, + std::enable_if_t<std::conjunction<std::negation<std::is_same<T, U>>, std::is_constructible<T, U&&>, std::is_convertible<U&&, T>>::value, int>> @@ -261,7 +231,7 @@ template <typename T> template < typename U, - std::enable_if_t<cros_tclib::conjunction<cros_tclib::negation<std::is_same<T, U>>, + std::enable_if_t<std::conjunction<std::negation<std::is_same<T, U>>, std::is_constructible<T, const U&>, std::is_convertible<const U&, T>>::value, int>> @@ -269,7 +239,7 @@ template <typename T> template <typename U, - std::enable_if_t<cros_tclib::conjunction<cros_tclib::negation<std::is_same<T, U>>, + std::enable_if_t<std::conjunction<std::negation<std::is_same<T, U>>, std::is_constructible<T, U&&>, std::is_convertible<U&&, T>>::value, int>> diff --git a/utils/flatbuffers/reflection.h b/utils/flatbuffers/reflection.h index 8c00a0e..1ac5e0a 100644 --- a/utils/flatbuffers/reflection.h +++ b/utils/flatbuffers/reflection.h
@@ -13,8 +13,6 @@ // limitations under the License. // -#pragma GCC diagnostic ignored "-Wc++17-extensions" - // Utility functions for working with FlatBuffers. #ifndef LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_REFLECTION_H_ @@ -86,6 +84,64 @@ inline const reflection::BaseType flatbuffers_base_type<StringPiece>::value = reflection::String; +template <reflection::BaseType> +struct flatbuffers_cpp_type; + +template <> +struct flatbuffers_cpp_type<reflection::BaseType::Bool> { + using value = bool; +}; + +template <> +struct flatbuffers_cpp_type<reflection::BaseType::Byte> { + using value = int8; +}; + +template <> +struct flatbuffers_cpp_type<reflection::BaseType::UByte> { + using value = uint8; +}; + +template <> +struct flatbuffers_cpp_type<reflection::BaseType::Short> { + using value = int16; +}; + +template <> +struct flatbuffers_cpp_type<reflection::BaseType::UShort> { + using value = uint16; +}; + +template <> +struct flatbuffers_cpp_type<reflection::BaseType::Int> { + using value = int32; +}; + +template <> +struct flatbuffers_cpp_type<reflection::BaseType::UInt> { + using value = uint32; +}; + +template <> +struct flatbuffers_cpp_type<reflection::BaseType::Long> { + using value = int64; +}; + +template <> +struct flatbuffers_cpp_type<reflection::BaseType::ULong> { + using value = uint64; +}; + +template <> +struct flatbuffers_cpp_type<reflection::BaseType::Float> { + using value = float; +}; + +template <> +struct flatbuffers_cpp_type<reflection::BaseType::Double> { + using value = double; +}; + // Gets the field information for a field name, returns nullptr if the // field was not defined. const reflection::Field* GetFieldOrNull(const reflection::Object* type, diff --git a/utils/grammar/analyzer.cc b/utils/grammar/analyzer.cc new file mode 100644 index 0000000..c390c3e --- /dev/null +++ b/utils/grammar/analyzer.cc
@@ -0,0 +1,81 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "utils/grammar/analyzer.h" + +#include "utils/base/status_macros.h" +#include "utils/utf8/unicodetext.h" + +namespace libtextclassifier3::grammar { + +Analyzer::Analyzer(const UniLib* unilib, const RulesSet* rules_set) + // TODO(smillius): Add tokenizer options to `RulesSet`. + : owned_tokenizer_(new Tokenizer(libtextclassifier3::TokenizationType_ICU, + unilib, + /*codepoint_ranges=*/{}, + /*internal_tokenizer_codepoint_ranges=*/{}, + /*split_on_script_change=*/false, + /*icu_preserve_whitespace_tokens=*/false)), + tokenizer_(owned_tokenizer_.get()), + parser_(unilib, rules_set), + semantic_evaluator_(rules_set->semantic_values_schema() != nullptr + ? flatbuffers::GetRoot<reflection::Schema>( + rules_set->semantic_values_schema()->data()) + : nullptr) {} + +Analyzer::Analyzer(const UniLib* unilib, const RulesSet* rules_set, + const Tokenizer* tokenizer) + : tokenizer_(tokenizer), + parser_(unilib, rules_set), + semantic_evaluator_(rules_set->semantic_values_schema() != nullptr + ? flatbuffers::GetRoot<reflection::Schema>( + rules_set->semantic_values_schema()->data()) + : nullptr) {} + +StatusOr<std::vector<EvaluatedDerivation>> Analyzer::Parse( + const TextContext& input, UnsafeArena* arena) const { + std::vector<EvaluatedDerivation> result; + + // Evaluate each derivation. 
+ for (const Derivation& derivation : + ValidDeduplicatedDerivations(parser_.Parse(input, arena))) { + TC3_ASSIGN_OR_RETURN(const SemanticValue* value, + semantic_evaluator_.Eval(input, derivation, arena)); + result.emplace_back(EvaluatedDerivation{std::move(derivation), value}); + } + + return result; +} + +StatusOr<std::vector<EvaluatedDerivation>> Analyzer::Parse( + const UnicodeText& text, const std::vector<Locale>& locales, + UnsafeArena* arena) const { + return Parse(BuildTextContextForInput(text, locales), arena); +} + +TextContext Analyzer::BuildTextContextForInput( + const UnicodeText& text, const std::vector<Locale>& locales) const { + TextContext context; + context.text = UnicodeText(text, /*do_copy=*/false); + context.tokens = tokenizer_->Tokenize(context.text); + context.codepoints = context.text.Codepoints(); + context.codepoints.push_back(context.text.end()); + context.locales = locales; + context.context_span.first = 0; + context.context_span.second = context.tokens.size(); + return context; +} + +} // namespace libtextclassifier3::grammar diff --git a/utils/grammar/analyzer.h b/utils/grammar/analyzer.h new file mode 100644 index 0000000..f3be919 --- /dev/null +++ b/utils/grammar/analyzer.h
@@ -0,0 +1,61 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_ANALYZER_H_ +#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_ANALYZER_H_ + +#include "utils/base/arena.h" +#include "utils/base/statusor.h" +#include "utils/grammar/evaluated-derivation.h" +#include "utils/grammar/parsing/parser.h" +#include "utils/grammar/semantics/composer.h" +#include "utils/grammar/text-context.h" +#include "utils/i18n/locale.h" +#include "utils/tokenizer.h" +#include "utils/utf8/unilib.h" + +namespace libtextclassifier3::grammar { + +// An analyzer that parses and semantically evaluates an input text with a +// grammar. +class Analyzer { + public: + explicit Analyzer(const UniLib* unilib, const RulesSet* rules_set); + explicit Analyzer(const UniLib* unilib, const RulesSet* rules_set, + const Tokenizer* tokenizer); + + // Parses and evaluates an input. + StatusOr<std::vector<EvaluatedDerivation>> Parse(const TextContext& input, + UnsafeArena* arena) const; + StatusOr<std::vector<EvaluatedDerivation>> Parse( + const UnicodeText& text, const std::vector<Locale>& locales, + UnsafeArena* arena) const; + + // Pre-processes an input text for parsing. 
+ TextContext BuildTextContextForInput( + const UnicodeText& text, const std::vector<Locale>& locales = {}) const; + + const Parser& parser() const { return parser_; } + + private: + std::unique_ptr<Tokenizer> owned_tokenizer_; + const Tokenizer* tokenizer_; + Parser parser_; + SemanticComposer semantic_evaluator_; +}; + +} // namespace libtextclassifier3::grammar + +#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_ANALYZER_H_ diff --git a/utils/grammar/callback-delegate.h b/utils/grammar/callback-delegate.h deleted file mode 100644 index 54eca93..0000000 --- a/utils/grammar/callback-delegate.h +++ /dev/null
@@ -1,44 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -#pragma GCC diagnostic ignored "-Wc++17-extensions" - -#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_CALLBACK_DELEGATE_H_ -#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_CALLBACK_DELEGATE_H_ - -#include "utils/base/integral_types.h" -#include "utils/grammar/match.h" -#include "utils/grammar/rules_generated.h" -#include "utils/grammar/types.h" - -namespace libtextclassifier3::grammar { - -class Matcher; - -// CallbackDelegate is an interface and default implementation used by the -// grammar matcher to dispatch rule matches. -class CallbackDelegate { - public: - virtual ~CallbackDelegate() = default; - - // This is called by the matcher whenever it finds a match for a rule to - // which a callback is attached. - virtual void MatchFound(const Match* match, const CallbackId callback_id, - const int64 callback_param, Matcher* matcher) {} -}; - -} // namespace libtextclassifier3::grammar - -#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_CALLBACK_DELEGATE_H_ diff --git a/utils/grammar/evaluated-derivation.h b/utils/grammar/evaluated-derivation.h new file mode 100644 index 0000000..b4723f6 --- /dev/null +++ b/utils/grammar/evaluated-derivation.h
@@ -0,0 +1,32 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_EVALUATED_DERIVATION_H_ +#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_EVALUATED_DERIVATION_H_ + +#include "utils/grammar/parsing/derivation.h" +#include "utils/grammar/semantics/value.h" + +namespace libtextclassifier3::grammar { + +// A parse tree for a root rule and its semantic value. +struct EvaluatedDerivation { + Derivation derivation; + const SemanticValue* value; +}; + +}; // namespace libtextclassifier3::grammar + +#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_EVALUATED_DERIVATION_H_ diff --git a/utils/grammar/lexer.cc b/utils/grammar/lexer.cc deleted file mode 100644 index 75c63f4..0000000 --- a/utils/grammar/lexer.cc +++ /dev/null
@@ -1,320 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -#include "utils/grammar/lexer.h" - -#include <unordered_map> - -#include "annotator/types.h" -#include "utils/zlib/tclib_zlib.h" -#include "utils/zlib/zlib_regex.h" - -namespace libtextclassifier3::grammar { -namespace { - -inline bool CheckMemoryUsage(const Matcher* matcher) { - // The maximum memory usage for matching. - constexpr int kMaxMemoryUsage = 1 << 20; - return matcher->ArenaSize() <= kMaxMemoryUsage; -} - -Match* CheckedAddMatch(const Nonterm nonterm, - const CodepointSpan codepoint_span, - const int match_offset, const int16 type, - Matcher* matcher) { - if (nonterm == kUnassignedNonterm || !CheckMemoryUsage(matcher)) { - return nullptr; - } - return matcher->AllocateAndInitMatch<Match>(nonterm, codepoint_span, - match_offset, type); -} - -void CheckedEmit(const Nonterm nonterm, const CodepointSpan codepoint_span, - const int match_offset, int16 type, Matcher* matcher) { - if (nonterm != kUnassignedNonterm && CheckMemoryUsage(matcher)) { - matcher->AddMatch(matcher->AllocateAndInitMatch<Match>( - nonterm, codepoint_span, match_offset, type)); - } -} - -int MapCodepointToTokenPaddingIfPresent( - const std::unordered_map<CodepointIndex, CodepointIndex>& token_alignment, - const int start) { - const auto it = token_alignment.find(start); - if (it != token_alignment.end()) { - return it->second; - } - return start; -} - -} // namespace - 
-Lexer::Lexer(const UniLib* unilib, const RulesSet* rules) - : unilib_(*unilib), - rules_(rules), - regex_annotators_(BuildRegexAnnotator(unilib_, rules)) {} - -std::vector<Lexer::RegexAnnotator> Lexer::BuildRegexAnnotator( - const UniLib& unilib, const RulesSet* rules) const { - std::vector<Lexer::RegexAnnotator> result; - if (rules->regex_annotator() != nullptr) { - std::unique_ptr<ZlibDecompressor> decompressor = - ZlibDecompressor::Instance(); - result.reserve(rules->regex_annotator()->size()); - for (const RulesSet_::RegexAnnotator* regex_annotator : - *rules->regex_annotator()) { - result.push_back( - {UncompressMakeRegexPattern(unilib_, regex_annotator->pattern(), - regex_annotator->compressed_pattern(), - rules->lazy_regex_compilation(), - decompressor.get()), - regex_annotator->nonterminal()}); - } - } - return result; -} - -void Lexer::Emit(const Symbol& symbol, const RulesSet_::Nonterminals* nonterms, - Matcher* matcher) const { - switch (symbol.type) { - case Symbol::Type::TYPE_MATCH: { - // Just emit the match. - matcher->AddMatch(symbol.match); - return; - } - case Symbol::Type::TYPE_DIGITS: { - // Emit <digits> if used by the rules. - CheckedEmit(nonterms->digits_nt(), symbol.codepoint_span, - symbol.match_offset, Match::kDigitsType, matcher); - - // Emit <n_digits> if used by the rules. - if (nonterms->n_digits_nt() != nullptr) { - const int num_digits = - symbol.codepoint_span.second - symbol.codepoint_span.first; - if (num_digits <= nonterms->n_digits_nt()->size()) { - CheckedEmit(nonterms->n_digits_nt()->Get(num_digits - 1), - symbol.codepoint_span, symbol.match_offset, - Match::kDigitsType, matcher); - } - } - break; - } - case Symbol::Type::TYPE_TERM: { - // Emit <uppercase_token> if used by the rules. 
- if (nonterms->uppercase_token_nt() != 0 && - unilib_.IsUpperText( - UTF8ToUnicodeText(symbol.lexeme, /*do_copy=*/false))) { - CheckedEmit(nonterms->uppercase_token_nt(), symbol.codepoint_span, - symbol.match_offset, Match::kTokenType, matcher); - } - break; - } - default: - break; - } - - // Emit the token as terminal. - if (CheckMemoryUsage(matcher)) { - matcher->AddTerminal(symbol.codepoint_span, symbol.match_offset, - symbol.lexeme); - } - - // Emit <token> if used by rules. - CheckedEmit(nonterms->token_nt(), symbol.codepoint_span, symbol.match_offset, - Match::kTokenType, matcher); -} - -Lexer::Symbol::Type Lexer::GetSymbolType( - const UnicodeText::const_iterator& it) const { - if (unilib_.IsPunctuation(*it)) { - return Symbol::Type::TYPE_PUNCTUATION; - } else if (unilib_.IsDigit(*it)) { - return Symbol::Type::TYPE_DIGITS; - } else { - return Symbol::Type::TYPE_TERM; - } -} - -void Lexer::ProcessToken(const StringPiece value, const int prev_token_end, - const CodepointSpan codepoint_span, - std::vector<Lexer::Symbol>* symbols) const { - // Possibly split token. 
- UnicodeText token_unicode = UTF8ToUnicodeText(value.data(), value.size(), - /*do_copy=*/false); - int last_end = prev_token_end; - auto token_end = token_unicode.end(); - auto it = token_unicode.begin(); - Symbol::Type type = GetSymbolType(it); - CodepointIndex sub_token_start = codepoint_span.first; - while (it != token_end) { - auto next = std::next(it); - int num_codepoints = 1; - Symbol::Type next_type; - while (next != token_end) { - next_type = GetSymbolType(next); - if (type == Symbol::Type::TYPE_PUNCTUATION || next_type != type) { - break; - } - ++next; - ++num_codepoints; - } - symbols->push_back(Symbol{ - type, CodepointSpan{sub_token_start, sub_token_start + num_codepoints}, - /*match_offset=*/last_end, - /*lexeme=*/ - StringPiece(it.utf8_data(), next.utf8_data() - it.utf8_data())}); - last_end = sub_token_start + num_codepoints; - it = next; - type = next_type; - sub_token_start = last_end; - } -} - -void Lexer::Process(const UnicodeText& text, const std::vector<Token>& tokens, - const std::vector<AnnotatedSpan>* annotations, - Matcher* matcher) const { - return Process(text, tokens.begin(), tokens.end(), annotations, matcher); -} - -void Lexer::Process(const UnicodeText& text, - const std::vector<Token>::const_iterator& begin, - const std::vector<Token>::const_iterator& end, - const std::vector<AnnotatedSpan>* annotations, - Matcher* matcher) const { - if (begin == end) { - return; - } - - const RulesSet_::Nonterminals* nonterminals = rules_->nonterminals(); - - // Initialize processing of new text. - CodepointIndex prev_token_end = 0; - std::vector<Symbol> symbols; - matcher->Reset(); - - // The matcher expects the terminals and non-terminals it received to be in - // non-decreasing end-position order. The sorting above makes sure the - // pre-defined matches adhere to that order. - // Ideally, we would just have to emit a predefined match whenever we see that - // the next token we feed would be ending later. 
- // But as we implicitly ignore whitespace, we have to merge preceding - // whitespace to the match start so that tokens and non-terminals fed appear - // as next to each other without whitespace. - // We keep track of real token starts and precending whitespace in - // `token_match_start`, so that we can extend a predefined match's start to - // include the preceding whitespace. - std::unordered_map<CodepointIndex, CodepointIndex> token_match_start; - - // Add start symbols. - if (Match* match = - CheckedAddMatch(nonterminals->start_nt(), CodepointSpan{0, 0}, - /*match_offset=*/0, Match::kBreakType, matcher)) { - symbols.push_back(Symbol(match)); - } - if (Match* match = - CheckedAddMatch(nonterminals->wordbreak_nt(), CodepointSpan{0, 0}, - /*match_offset=*/0, Match::kBreakType, matcher)) { - symbols.push_back(Symbol(match)); - } - - for (auto token_it = begin; token_it != end; token_it++) { - const Token& token = *token_it; - - // Record match starts for token boundaries, so that we can snap pre-defined - // matches to it. - if (prev_token_end != token.start) { - token_match_start[token.start] = prev_token_end; - } - - ProcessToken(token.value, - /*prev_token_end=*/prev_token_end, - CodepointSpan{token.start, token.end}, &symbols); - prev_token_end = token.end; - - // Add word break symbol if used by the grammar. - if (Match* match = CheckedAddMatch( - nonterminals->wordbreak_nt(), CodepointSpan{token.end, token.end}, - /*match_offset=*/token.end, Match::kBreakType, matcher)) { - symbols.push_back(Symbol(match)); - } - } - - // Add end symbol if used by the grammar. - if (Match* match = CheckedAddMatch( - nonterminals->end_nt(), CodepointSpan{prev_token_end, prev_token_end}, - /*match_offset=*/prev_token_end, Match::kBreakType, matcher)) { - symbols.push_back(Symbol(match)); - } - - // Add matches based on annotations. 
- auto annotation_nonterminals = nonterminals->annotation_nt(); - if (annotation_nonterminals != nullptr && annotations != nullptr) { - for (const AnnotatedSpan& annotated_span : *annotations) { - const ClassificationResult& classification = - annotated_span.classification.front(); - if (auto entry = annotation_nonterminals->LookupByKey( - classification.collection.c_str())) { - AnnotationMatch* match = matcher->AllocateAndInitMatch<AnnotationMatch>( - entry->value(), annotated_span.span, - /*match_offset=*/ - MapCodepointToTokenPaddingIfPresent(token_match_start, - annotated_span.span.first), - Match::kAnnotationMatch); - match->annotation = &classification; - symbols.push_back(Symbol(match)); - } - } - } - - // Add regex annotator matches for the range covered by the tokens. - for (const RegexAnnotator& regex_annotator : regex_annotators_) { - std::unique_ptr<UniLib::RegexMatcher> regex_matcher = - regex_annotator.pattern->Matcher(UnicodeText::Substring( - text, begin->start, prev_token_end, /*do_copy=*/false)); - int status = UniLib::RegexMatcher::kNoError; - while (regex_matcher->Find(&status) && - status == UniLib::RegexMatcher::kNoError) { - const CodepointSpan span = { - regex_matcher->Start(0, &status) + begin->start, - regex_matcher->End(0, &status) + begin->start}; - if (Match* match = - CheckedAddMatch(regex_annotator.nonterm, span, /*match_offset=*/ - MapCodepointToTokenPaddingIfPresent( - token_match_start, span.first), - Match::kUnknownType, matcher)) { - symbols.push_back(Symbol(match)); - } - } - } - - std::sort(symbols.begin(), symbols.end(), - [](const Symbol& a, const Symbol& b) { - // Sort by increasing (end, start) position to guarantee the - // matcher requirement that the tokens are fed in non-decreasing - // end position order. - return std::tie(a.codepoint_span.second, a.codepoint_span.first) < - std::tie(b.codepoint_span.second, b.codepoint_span.first); - }); - - // Emit symbols to matcher. 
- for (const Symbol& symbol : symbols) { - Emit(symbol, nonterminals, matcher); - } - - // Finish the matching. - matcher->Finish(); -} - -} // namespace libtextclassifier3::grammar diff --git a/utils/grammar/lexer.h b/utils/grammar/lexer.h deleted file mode 100644 index 6ca5f08..0000000 --- a/utils/grammar/lexer.h +++ /dev/null
@@ -1,177 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -#pragma GCC diagnostic ignored "-Wc++17-extensions" - -// This is a lexer that runs off the tokenizer and outputs the tokens to a -// grammar matcher. The tokens it forwards are the same as the ones produced -// by the tokenizer, but possibly further split and normalized (downcased). -// Examples: -// -// - single character tokens for punctuation (e.g., AddTerminal("?")) -// -// - a string of letters (e.g., "Foo" -- it calls AddTerminal() on "foo") -// -// - a string of digits (e.g., AddTerminal("37")) -// -// In addition to the terminal tokens above, it also outputs certain -// special nonterminals: -// -// - a <token> nonterminal, which it outputs in addition to the -// regular AddTerminal() call for every token -// -// - a <digits> nonterminal, which it outputs in addition to -// the regular AddTerminal() call for each string of digits -// -// - <N_digits> nonterminals, where N is the length of the string of -// digits. By default the maximum N that will be output is 20. This -// may be changed at compile time by kMaxNDigitsLength. For instance, -// "123" will produce a <3_digits> nonterminal, "1234567" will produce -// a <7_digits> nonterminal. -// -// It does not output any whitespace. Instead, whitespace gets absorbed into -// the token that follows them in the text. -// For example, if the text contains: -// -// ...hello there world... 
-// | | | -// offset=16 39 52 -// -// then the output will be: -// -// "hello" [?, 16) -// "there" [16, 44) <-- note "16" NOT "39" -// "world" [44, ?) <-- note "44" NOT "52" -// -// This makes it appear to the Matcher as if the tokens are adjacent -- so -// whitespace is simply ignored. -// -// A minor optimization: We don't bother to output nonterminals if the grammar -// rules don't reference them. - -#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_LEXER_H_ -#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_LEXER_H_ - -#include "annotator/types.h" -#include "utils/grammar/matcher.h" -#include "utils/grammar/rules_generated.h" -#include "utils/grammar/types.h" -#include "utils/strings/stringpiece.h" -#include "utils/utf8/unicodetext.h" -#include "utils/utf8/unilib.h" - -namespace libtextclassifier3::grammar { - -class Lexer { - public: - explicit Lexer(const UniLib* unilib, const RulesSet* rules); - - // Processes a tokenized text. Classifies the tokens and feeds them to the - // matcher. - // The provided annotations will be fed to the matcher alongside the tokens. - // NOTE: The `annotations` need to outlive any dependent processing. - void Process(const UnicodeText& text, const std::vector<Token>& tokens, - const std::vector<AnnotatedSpan>* annotations, - Matcher* matcher) const; - void Process(const UnicodeText& text, - const std::vector<Token>::const_iterator& begin, - const std::vector<Token>::const_iterator& end, - const std::vector<AnnotatedSpan>* annotations, - Matcher* matcher) const; - - private: - // A lexical symbol with an identified meaning that represents raw tokens, - // token categories or predefined text matches. - // It is the unit fed to the grammar matcher. - struct Symbol { - // The type of the lexical symbol. - enum class Type { - // A raw token. - TYPE_TERM, - - // A symbol representing a string of digits. - TYPE_DIGITS, - - // Punctuation characters. - TYPE_PUNCTUATION, - - // A predefined match. 
- TYPE_MATCH - }; - - explicit Symbol() = default; - - // Constructs a symbol of a given type with an anchor in the text. - Symbol(const Type type, const CodepointSpan codepoint_span, - const int match_offset, StringPiece lexeme) - : type(type), - codepoint_span(codepoint_span), - match_offset(match_offset), - lexeme(lexeme) {} - - // Constructs a symbol from a pre-defined match. - explicit Symbol(Match* match) - : type(Type::TYPE_MATCH), - codepoint_span(match->codepoint_span), - match_offset(match->match_offset), - match(match) {} - - // The type of the symbole. - Type type; - - // The span in the text as codepoint offsets. - CodepointSpan codepoint_span; - - // The match start offset (including preceding whitespace) as codepoint - // offset. - int match_offset; - - // The symbol text value. - StringPiece lexeme; - - // The predefined match. - Match* match; - }; - - // Processes a single token: the token is split and classified into symbols. - void ProcessToken(const StringPiece value, const int prev_token_end, - const CodepointSpan codepoint_span, - std::vector<Symbol>* symbols) const; - - // Emits a token to the matcher. - void Emit(const Symbol& symbol, const RulesSet_::Nonterminals* nonterms, - Matcher* matcher) const; - - // Gets the type of a character. - Symbol::Type GetSymbolType(const UnicodeText::const_iterator& it) const; - - private: - struct RegexAnnotator { - std::unique_ptr<UniLib::RegexPattern> pattern; - Nonterm nonterm; - }; - - // Uncompress and build the defined regex annotators. - std::vector<RegexAnnotator> BuildRegexAnnotator(const UniLib& unilib, - const RulesSet* rules) const; - - const UniLib& unilib_; - const RulesSet* rules_; - std::vector<RegexAnnotator> regex_annotators_; -}; - -} // namespace libtextclassifier3::grammar - -#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_LEXER_H_ diff --git a/utils/grammar/match.cc b/utils/grammar/match.cc deleted file mode 100644 index 2c6452e..0000000 --- a/utils/grammar/match.cc +++ /dev/null
@@ -1,76 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -#include "utils/grammar/match.h" - -#include <algorithm> -#include <stack> - -namespace libtextclassifier3::grammar { - -void Traverse(const Match* root, - const std::function<bool(const Match*)>& node_fn) { - std::stack<const Match*> open; - open.push(root); - - while (!open.empty()) { - const Match* node = open.top(); - open.pop(); - if (!node_fn(node) || node->IsLeaf()) { - continue; - } - open.push(node->rhs2); - if (node->rhs1 != nullptr) { - open.push(node->rhs1); - } - } -} - -const Match* SelectFirst(const Match* root, - const std::function<bool(const Match*)>& pred_fn) { - std::stack<const Match*> open; - open.push(root); - - while (!open.empty()) { - const Match* node = open.top(); - open.pop(); - if (pred_fn(node)) { - return node; - } - if (node->IsLeaf()) { - continue; - } - open.push(node->rhs2); - if (node->rhs1 != nullptr) { - open.push(node->rhs1); - } - } - - return nullptr; -} - -std::vector<const Match*> SelectAll( - const Match* root, const std::function<bool(const Match*)>& pred_fn) { - std::vector<const Match*> result; - Traverse(root, [&result, pred_fn](const Match* node) { - if (pred_fn(node)) { - result.push_back(node); - } - return true; - }); - return result; -} - -} // namespace libtextclassifier3::grammar diff --git a/utils/grammar/match.h b/utils/grammar/match.h deleted file mode 100644 index a485e62..0000000 --- 
a/utils/grammar/match.h +++ /dev/null
@@ -1,173 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -#pragma GCC diagnostic ignored "-Wc++17-extensions" - -#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_MATCH_H_ -#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_MATCH_H_ - -#include <functional> -#include <vector> - -#include "annotator/types.h" -#include "utils/grammar/types.h" -#include "utils/strings/stringpiece.h" - -namespace libtextclassifier3::grammar { - -// Represents a single match that was found for a particular nonterminal. -// Instances should be created by calling Matcher::AllocateMatch(). -// This uses an arena to allocate matches (and subclasses thereof). 
-struct Match { - static constexpr int16 kUnknownType = 0; - static constexpr int16 kTokenType = -1; - static constexpr int16 kDigitsType = -2; - static constexpr int16 kBreakType = -3; - static constexpr int16 kAssertionMatch = -4; - static constexpr int16 kMappingMatch = -5; - static constexpr int16 kExclusionMatch = -6; - static constexpr int16 kAnnotationMatch = -7; - - void Init(const Nonterm arg_lhs, const CodepointSpan arg_codepoint_span, - const int arg_match_offset, const int arg_type = kUnknownType) { - lhs = arg_lhs; - codepoint_span = arg_codepoint_span; - match_offset = arg_match_offset; - type = arg_type; - rhs1 = nullptr; - rhs2 = nullptr; - } - - void Init(const Match& other) { *this = other; } - - // For binary rule matches: rhs1 != NULL and rhs2 != NULL - // unary rule matches: rhs1 == NULL and rhs2 != NULL - // terminal rule matches: rhs1 != NULL and rhs2 == NULL - // custom leaves: rhs1 == NULL and rhs2 == NULL - bool IsInteriorNode() const { return rhs2 != nullptr; } - bool IsLeaf() const { return !rhs2; } - - bool IsBinaryRule() const { return rhs1 && rhs2; } - bool IsUnaryRule() const { return !rhs1 && rhs2; } - bool IsTerminalRule() const { return rhs1 && !rhs2; } - bool HasLeadingWhitespace() const { - return codepoint_span.first != match_offset; - } - - const Match* unary_rule_rhs() const { return rhs2; } - - // Used in singly-linked queue of matches for processing. - Match* next = nullptr; - - // Nonterminal we found a match for. - Nonterm lhs = kUnassignedNonterm; - - // Type of the match. - int16 type = kUnknownType; - - // The span in codepoints. - CodepointSpan codepoint_span = CodepointSpan::kInvalid; - - // The begin codepoint offset used during matching. - // This is usually including any prefix whitespace. - int match_offset; - - union { - // The first sub match for binary rules. - const Match* rhs1 = nullptr; - - // The terminal, for terminal rules. - const char* terminal; - }; - // First or second sub-match for interior nodes. 
- const Match* rhs2 = nullptr; -}; - -// Match type to keep track of associated values. -struct MappingMatch : public Match { - // The associated id or value. - int64 id; -}; - -// Match type to keep track of assertions. -struct AssertionMatch : public Match { - // If true, the assertion is negative and will be valid if the input doesn't - // match. - bool negative; -}; - -// Match type to define exclusions. -struct ExclusionMatch : public Match { - // The nonterminal that denotes matches to exclude from a successful match. - // So the match is only valid if there is no match of `exclusion_nonterm` - // spanning the same text range. - Nonterm exclusion_nonterm; -}; - -// Match to represent an annotator annotated span in the grammar. -struct AnnotationMatch : public Match { - const ClassificationResult* annotation; -}; - -// Utility functions for parse tree traversal. - -// Does a preorder traversal, calling `node_fn` on each node. -// `node_fn` is expected to return whether to continue expanding a node. -void Traverse(const Match* root, - const std::function<bool(const Match*)>& node_fn); - -// Does a preorder traversal, calling `pred_fn` and returns the first node -// on which `pred_fn` returns true. -const Match* SelectFirst(const Match* root, - const std::function<bool(const Match*)>& pred_fn); - -// Does a preorder traversal, selecting all nodes where `pred_fn` returns true. -std::vector<const Match*> SelectAll( - const Match* root, const std::function<bool(const Match*)>& pred_fn); - -// Selects all terminals from a parse tree. -inline std::vector<const Match*> SelectTerminals(const Match* root) { - return SelectAll(root, &Match::IsTerminalRule); -} - -// Selects all leaves from a parse tree. -inline std::vector<const Match*> SelectLeaves(const Match* root) { - return SelectAll(root, &Match::IsLeaf); -} - -// Retrieves the first child node of a given type. 
-template <typename T> -const T* SelectFirstOfType(const Match* root, const int16 type) { - return static_cast<const T*>(SelectFirst( - root, [type](const Match* node) { return node->type == type; })); -} - -// Retrieves all nodes of a given type. -template <typename T> -const std::vector<const T*> SelectAllOfType(const Match* root, - const int16 type) { - std::vector<const T*> result; - Traverse(root, [&result, type](const Match* node) { - if (node->type == type) { - result.push_back(static_cast<const T*>(node)); - } - return true; - }); - return result; -} - -} // namespace libtextclassifier3::grammar - -#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_MATCH_H_ diff --git a/utils/grammar/matcher.h b/utils/grammar/matcher.h deleted file mode 100644 index 1fdad84..0000000 --- a/utils/grammar/matcher.h +++ /dev/null
@@ -1,247 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -#pragma GCC diagnostic ignored "-Wc++17-extensions" - -// A token matcher based on context-free grammars. -// -// A lexer passes token to the matcher: literal terminal strings and token -// types. It passes tokens to the matcher by calling AddTerminal() and -// AddMatch() for literal terminals and token types, respectively. -// The lexer passes each token along with the [begin, end) position range -// in which it occurs. So for an input string "Groundhog February 2, 2007", the -// lexer would tell the matcher that: -// -// "Groundhog" occurs at [0, 9) -// <space> occurs at [9, 10) -// "February" occurs at [10, 18) -// <space> occurs at [18, 19) -// <string_of_digits> occurs at [19, 20) -// "," occurs at [20, 21) -// <space> occurs at [21, 22) -// <string_of_digits> occurs at [22, 26) -// -// The lexer passes tokens to the matcher by calling AddTerminal() and -// AddMatch() for literal terminals and token types, respectively. -// -// Although it is unnecessary for this example grammar, a lexer can -// output multiple tokens for the same input range. 
So our lexer could -// additionally output: -// "2" occurs at [19, 20) // a second token for [19, 20) -// "2007" occurs at [22, 26) -// <syllable> occurs at [0, 6) // overlaps with (Groundhog [0, 9)) -// <syllable> occurs at [6, 9) -// The only constraint on the lexer's output is that it has to pass tokens -// to the matcher in left-to-right order, strictly speaking, their "end" -// positions must be nondecreasing. (This constraint allows a more -// efficient matching algorithm.) The "begin" positions can be in any -// order. -// -// There are two kinds of supported callbacks: -// (1) OUTPUT: Callbacks are the only output mechanism a matcher has. For each -// "top-level" rule in your grammar, like the rule for <date> above -- something -// you're trying to find instances of -- you use a callback which the matcher -// will invoke every time it finds an instance of <date>. -// (2) FILTERS: -// Callbacks allow you to put extra conditions on when a grammar rule -// applies. In the example grammar, the rule -// -// <day> ::= <string_of_digits> // must be between 1 and 31 -// -// should only apply for *some* <string_of_digits> tokens, not others. -// By using a filter callback on this rule, you can tell the matcher that -// an instance of the rule's RHS is only *sometimes* considered an -// instance of its LHS. The filter callback will get invoked whenever -// the matcher finds an instance of <string_of_digits>. The callback can -// look at the digits and decide whether they represent a number between -// 1 and 31. If so, the callback calls Matcher::AddMatch() to tell the -// matcher there's a <day> there. If not, the callback simply exits -// without calling AddMatch(). -// -// Technically, a FILTER callback can make any number of calls to -// AddMatch() or even AddTerminal(). But the expected usage is to just -// make zero or one call to AddMatch(). 
OUTPUT callbacks are not expected -// to call either of these -- output callbacks are invoked merely as a -// side-effect, not in order to decide whether a rule applies or not. -// -// In the above example, you would probably use three callbacks. Filter -// callbacks on the rules for <day> and <year> would check the numeric -// value of the <string_of_digits>. An output callback on the rule for -// <date> would simply increment the counter of dates found on the page. -// -// Note that callbacks are attached to rules, not to nonterminals. You -// could have two alternative rules for <date> and use a different -// callback for each one. - -#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_MATCHER_H_ -#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_MATCHER_H_ - -#include <array> -#include <functional> -#include <vector> - -#include "annotator/types.h" -#include "utils/base/arena.h" -#include "utils/grammar/callback-delegate.h" -#include "utils/grammar/match.h" -#include "utils/grammar/rules_generated.h" -#include "utils/strings/stringpiece.h" -#include "utils/utf8/unilib.h" - -namespace libtextclassifier3::grammar { - -class Matcher { - public: - explicit Matcher(const UniLib* unilib, const RulesSet* rules, - const std::vector<const RulesSet_::Rules*> rules_shards, - CallbackDelegate* delegate) - : state_(STATE_DEFAULT), - unilib_(*unilib), - arena_(kBlocksize), - rules_(rules), - rules_shards_(rules_shards), - delegate_(delegate) { - TC3_CHECK(rules_ != nullptr); - Reset(); - } - explicit Matcher(const UniLib* unilib, const RulesSet* rules, - CallbackDelegate* delegate) - : Matcher(unilib, rules, {}, delegate) { - rules_shards_.reserve(rules->rules()->size()); - rules_shards_.insert(rules_shards_.end(), rules->rules()->begin(), - rules->rules()->end()); - } - - // Resets the matcher. - void Reset(); - - // Finish the matching. - void Finish(); - - // Tells the matcher that the given terminal was found occupying position - // range [begin, end) in the input. 
- // The matcher may invoke callback functions before returning, if this - // terminal triggers any new matches for rules in the grammar. - // Calls to AddTerminal() and AddMatch() must be in left-to-right order, - // that is, the sequence of `end` values must be non-decreasing. - void AddTerminal(const CodepointSpan codepoint_span, const int match_offset, - StringPiece terminal); - void AddTerminal(const CodepointIndex begin, const CodepointIndex end, - StringPiece terminal) { - AddTerminal(CodepointSpan{begin, end}, begin, terminal); - } - - // Adds a nonterminal match to the chart. - // This can be invoked by the lexer if the lexer needs to add nonterminals to - // the chart. - void AddMatch(Match* match); - - // Allocates memory from an area for a new match. - // The `size` parameter is there to allow subclassing of the match object - // with additional fields. - Match* AllocateMatch(const size_t size) { - return reinterpret_cast<Match*>(arena_.Alloc(size)); - } - - template <typename T> - T* AllocateMatch() { - return reinterpret_cast<T*>(arena_.Alloc(sizeof(T))); - } - - template <typename T, typename... Args> - T* AllocateAndInitMatch(Args... args) { - T* match = AllocateMatch<T>(); - match->Init(args...); - return match; - } - - // Returns the current number of bytes allocated for all match objects. - size_t ArenaSize() const { return arena_.status().bytes_allocated(); } - - private: - static constexpr int kBlocksize = 16 << 10; - - // The state of the matcher. - enum State { - // The matcher is in the default state. - STATE_DEFAULT = 0, - - // The matcher is currently processing queued match items. - STATE_PROCESSING = 1, - }; - State state_; - - // Process matches from lhs set. - void ExecuteLhsSet(const CodepointSpan codepoint_span, const int match_offset, - const int whitespace_gap, - const std::function<void(Match*)>& initializer, - const RulesSet_::LhsSet* lhs_set, - CallbackDelegate* delegate); - - // Queues a newly created match item. 
- void QueueForProcessing(Match* item); - - // Queues a match item for later post checking of the exclusion condition. - // For exclusions we need to check that the `item->excluded_nonterminal` - // doesn't match the same span. As we cannot know which matches have already - // been added, we queue the item for later post checking - once all matches - // up to `item->codepoint_span.second` have been added. - void QueueForPostCheck(ExclusionMatch* item); - - // Adds pending items to the chart, possibly generating new matches as a - // result. - void ProcessPendingSet(); - - // Returns whether the chart contains a match for a given nonterminal. - bool ContainsMatch(const Nonterm nonterm, const CodepointSpan& span) const; - - // Checks all pending exclusion matches that their exclusion condition is - // fulfilled. - void ProcessPendingExclusionMatches(); - - UniLib unilib_; - - // Memory arena for match allocation. - UnsafeArena arena_; - - // The end position of the most recent match or terminal, for sanity - // checking. - int last_end_; - - // Rules. - const RulesSet* rules_; - - // The set of items pending to be added to the chart as a singly-linked list. - Match* pending_items_; - - // The set of items pending to be post-checked as a singly-linked list. - ExclusionMatch* pending_exclusion_items_; - - // The chart data structure: a hashtable containing all matches, indexed by - // their end positions. - static constexpr int kChartHashTableNumBuckets = 1 << 8; - static constexpr int kChartHashTableBitmask = kChartHashTableNumBuckets - 1; - std::array<Match*, kChartHashTableNumBuckets> chart_; - - // The active rule shards. - std::vector<const RulesSet_::Rules*> rules_shards_; - - // The callback handler. 
- CallbackDelegate* delegate_; -}; - -} // namespace libtextclassifier3::grammar - -#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_MATCHER_H_ diff --git a/utils/grammar/parsing/chart.h b/utils/grammar/parsing/chart.h new file mode 100644 index 0000000..1d4aa55 --- /dev/null +++ b/utils/grammar/parsing/chart.h
@@ -0,0 +1,107 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_CHART_H_ +#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_CHART_H_ + +#include <array> + +#include "annotator/types.h" +#include "utils/grammar/parsing/derivation.h" +#include "utils/grammar/parsing/parse-tree.h" + +namespace libtextclassifier3::grammar { + +// Chart is a hashtable container for use with a CYK style parser. +// The hashtable contains all matches, indexed by their end positions. +template <int NumBuckets = 1 << 8> +class Chart { + public: + explicit Chart() { std::fill(chart_.begin(), chart_.end(), nullptr); } + + // Iterator that allows iterating through recorded matches that end at a given + // match offset. + class Iterator { + public: + explicit Iterator(const int match_offset, const ParseTree* value) + : match_offset_(match_offset), value_(value) {} + + bool Done() const { + return value_ == nullptr || + (value_->codepoint_span.second < match_offset_); + } + const ParseTree* Item() const { return value_; } + void Next() { + TC3_DCHECK(!Done()); + value_ = value_->next; + } + + private: + const int match_offset_; + const ParseTree* value_; + }; + + // Returns whether the chart contains a match for a given nonterminal. + bool HasMatch(const Nonterm nonterm, const CodepointSpan& span) const; + + // Adds a match to the chart. 
+ void Add(ParseTree* item) { + item->next = chart_[item->codepoint_span.second & kChartHashTableBitmask]; + chart_[item->codepoint_span.second & kChartHashTableBitmask] = item; + } + + // Records a derivation of a root rule. + void AddDerivation(const Derivation& derivation) { + root_derivations_.push_back(derivation); + } + + // Returns an iterator through all matches ending at `match_offset`. + Iterator MatchesEndingAt(const int match_offset) const { + const ParseTree* value = chart_[match_offset & kChartHashTableBitmask]; + // The chain of items is in decreasing `end` order. + // Find the ones that have prev->end == item->begin. + while (value != nullptr && (value->codepoint_span.second > match_offset)) { + value = value->next; + } + return Iterator(match_offset, value); + } + + const std::vector<Derivation> derivations() const { + return root_derivations_; + } + + private: + static constexpr int kChartHashTableBitmask = NumBuckets - 1; + std::array<ParseTree*, NumBuckets> chart_; + std::vector<Derivation> root_derivations_; +}; + +template <int NumBuckets> +bool Chart<NumBuckets>::HasMatch(const Nonterm nonterm, + const CodepointSpan& span) const { + // Lookup by end. + for (Chart<NumBuckets>::Iterator it = MatchesEndingAt(span.second); + !it.Done(); it.Next()) { + if (it.Item()->lhs == nonterm && + it.Item()->codepoint_span.first == span.first) { + return true; + } + } + return false; +} + +} // namespace libtextclassifier3::grammar + +#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_CHART_H_ diff --git a/utils/grammar/parsing/derivation.cc b/utils/grammar/parsing/derivation.cc new file mode 100644 index 0000000..d0c5091 --- /dev/null +++ b/utils/grammar/parsing/derivation.cc
@@ -0,0 +1,100 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "utils/grammar/parsing/derivation.h" + +#include <algorithm> + +namespace libtextclassifier3::grammar { + +bool Derivation::IsValid() const { + bool result = true; + Traverse(parse_tree, [&result](const ParseTree* node) { + if (node->type != ParseTree::Type::kAssertion) { + // Only validation if all checks so far passed. + return result; + } + // Positive assertions are by definition fulfilled, + // fail if the assertion is negative. + if (static_cast<const AssertionNode*>(node)->negative) { + result = false; + } + return result; + }); + return result; +} + +std::vector<Derivation> DeduplicateDerivations( + const std::vector<Derivation>& derivations) { + std::vector<Derivation> sorted_candidates = derivations; + std::stable_sort(sorted_candidates.begin(), sorted_candidates.end(), + [](const Derivation& a, const Derivation& b) { + // Sort by id. + if (a.rule_id != b.rule_id) { + return a.rule_id < b.rule_id; + } + + // Sort by increasing start. + if (a.parse_tree->codepoint_span.first != + b.parse_tree->codepoint_span.first) { + return a.parse_tree->codepoint_span.first < + b.parse_tree->codepoint_span.first; + } + + // Sort by decreasing end. + return a.parse_tree->codepoint_span.second > + b.parse_tree->codepoint_span.second; + }); + + // Deduplicate by overlap. 
+ std::vector<Derivation> result; + for (int i = 0; i < sorted_candidates.size(); i++) { + const Derivation& candidate = sorted_candidates[i]; + bool eliminated = false; + + // Due to the sorting above, the candidate can only be completely + // intersected by a match before it in the sorted order. + for (int j = i - 1; j >= 0; j--) { + if (sorted_candidates[j].rule_id != candidate.rule_id) { + break; + } + if (sorted_candidates[j].parse_tree->codepoint_span.first <= + candidate.parse_tree->codepoint_span.first && + sorted_candidates[j].parse_tree->codepoint_span.second >= + candidate.parse_tree->codepoint_span.second) { + eliminated = true; + break; + } + } + if (!eliminated) { + result.push_back(candidate); + } + } + return result; +} + +std::vector<Derivation> ValidDeduplicatedDerivations( + const std::vector<Derivation>& derivations) { + std::vector<Derivation> result; + for (const Derivation& derivation : DeduplicateDerivations(derivations)) { + // Check that asserts are fulfilled. + if (derivation.IsValid()) { + result.push_back(derivation); + } + } + return result; +} + +} // namespace libtextclassifier3::grammar diff --git a/utils/grammar/parsing/derivation.h b/utils/grammar/parsing/derivation.h new file mode 100644 index 0000000..4994aef --- /dev/null +++ b/utils/grammar/parsing/derivation.h
@@ -0,0 +1,49 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_DERIVATION_H_ +#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_DERIVATION_H_ + +#include <vector> + +#include "utils/grammar/parsing/parse-tree.h" + +namespace libtextclassifier3::grammar { + +// A parse tree for a root rule. +struct Derivation { + const ParseTree* parse_tree; + int64 rule_id; + + // Checks that all assertions are fulfilled. + bool IsValid() const; +}; + +// Deduplicates rule derivations by containing overlap. +// The grammar system can output multiple candidates for optional parts. +// For example if a rule has an optional suffix, we +// will get two rule derivations when the suffix is present: one with and one +// without the suffix. We therefore deduplicate by containing overlap, viz. from +// two candidates we keep the longer one if it completely contains the shorter. +std::vector<Derivation> DeduplicateDerivations( + const std::vector<Derivation>& derivations); + +// Deduplicates and validates rule derivations. 
+std::vector<Derivation> ValidDeduplicatedDerivations( + const std::vector<Derivation>& derivations); + +} // namespace libtextclassifier3::grammar + +#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_DERIVATION_H_ diff --git a/utils/grammar/parsing/lexer.cc b/utils/grammar/parsing/lexer.cc new file mode 100644 index 0000000..b87889a --- /dev/null +++ b/utils/grammar/parsing/lexer.cc
@@ -0,0 +1,65 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "utils/grammar/parsing/lexer.h" + +namespace libtextclassifier3::grammar { + +Symbol::Type Lexer::GetSymbolType(const UnicodeText::const_iterator& it) const { + if (unilib_.IsPunctuation(*it)) { + return Symbol::Type::TYPE_PUNCTUATION; + } else if (unilib_.IsDigit(*it)) { + return Symbol::Type::TYPE_DIGITS; + } else { + return Symbol::Type::TYPE_TERM; + } +} + +void Lexer::AppendTokenSymbols(const StringPiece value, int match_offset, + const CodepointSpan codepoint_span, + std::vector<Symbol>* symbols) const { + // Possibly split token. 
+ UnicodeText token_unicode = UTF8ToUnicodeText(value.data(), value.size(), + /*do_copy=*/false); + int next_match_offset = match_offset; + auto token_end = token_unicode.end(); + auto it = token_unicode.begin(); + Symbol::Type type = GetSymbolType(it); + CodepointIndex sub_token_start = codepoint_span.first; + while (it != token_end) { + auto next = std::next(it); + int num_codepoints = 1; + Symbol::Type next_type; + while (next != token_end) { + next_type = GetSymbolType(next); + if (type == Symbol::Type::TYPE_PUNCTUATION || next_type != type) { + break; + } + ++next; + ++num_codepoints; + } + symbols->emplace_back( + type, CodepointSpan{sub_token_start, sub_token_start + num_codepoints}, + /*match_offset=*/next_match_offset, + /*lexeme=*/ + StringPiece(it.utf8_data(), next.utf8_data() - it.utf8_data())); + next_match_offset = sub_token_start + num_codepoints; + it = next; + type = next_type; + sub_token_start = next_match_offset; + } +} + +} // namespace libtextclassifier3::grammar diff --git a/utils/grammar/parsing/lexer.h b/utils/grammar/parsing/lexer.h new file mode 100644 index 0000000..9f13c29 --- /dev/null +++ b/utils/grammar/parsing/lexer.h
@@ -0,0 +1,119 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// A lexer that (splits) and classifies tokens. +// +// Any whitespace gets absorbed into the token that follows them in the text. +// For example, if the text contains: +// +// ...hello there world... +// | | | +// offset=16 39 52 +// +// then the output will be: +// +// "hello" [?, 16) +// "there" [16, 44) <-- note "16" NOT "39" +// "world" [44, ?) <-- note "44" NOT "52" +// +// This makes it appear to the Matcher as if the tokens are adjacent. + +#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_LEXER_H_ +#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_LEXER_H_ + +#include <vector> + +#include "annotator/types.h" +#include "utils/grammar/parsing/parse-tree.h" +#include "utils/grammar/types.h" +#include "utils/strings/stringpiece.h" +#include "utils/utf8/unicodetext.h" +#include "utils/utf8/unilib.h" + +namespace libtextclassifier3::grammar { + +// A lexical symbol with an identified meaning that represents raw tokens, +// token categories or predefined text matches. +// It is the unit fed to the grammar matcher. +struct Symbol { + // The type of the lexical symbol. + enum class Type { + // A raw token. + TYPE_TERM, + + // A symbol representing a string of digits. + TYPE_DIGITS, + + // Punctuation characters. + TYPE_PUNCTUATION, + + // A predefined parse tree. 
+ TYPE_PARSE_TREE + }; + + explicit Symbol() = default; + + // Constructs a symbol of a given type with an anchor in the text. + Symbol(const Type type, const CodepointSpan codepoint_span, + const int match_offset, StringPiece lexeme) + : type(type), + codepoint_span(codepoint_span), + match_offset(match_offset), + lexeme(lexeme) {} + + // Constructs a symbol from a pre-defined parse tree. + explicit Symbol(ParseTree* parse_tree) + : type(Type::TYPE_PARSE_TREE), + codepoint_span(parse_tree->codepoint_span), + match_offset(parse_tree->match_offset), + parse_tree(parse_tree) {} + + // The type of the symbol. + Type type; + + // The span in the text as codepoint offsets. + CodepointSpan codepoint_span; + + // The match start offset (including preceding whitespace) as codepoint + // offset. + int match_offset; + + // The symbol text value. + StringPiece lexeme; + + // The predefined parse tree. + ParseTree* parse_tree; +}; + +class Lexer { + public: + explicit Lexer(const UniLib* unilib) : unilib_(*unilib) {} + + // Processes a single token. + // Splits a token into classified symbols. + void AppendTokenSymbols(const StringPiece value, int match_offset, + const CodepointSpan codepoint_span, + std::vector<Symbol>* symbols) const; + + private: + // Gets the type of a character. + Symbol::Type GetSymbolType(const UnicodeText::const_iterator& it) const; + + const UniLib& unilib_; +}; + +} // namespace libtextclassifier3::grammar + +#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_LEXER_H_ diff --git a/utils/grammar/matcher.cc b/utils/grammar/parsing/matcher.cc similarity index 68% rename from utils/grammar/matcher.cc rename to utils/grammar/parsing/matcher.cc index fdc21a3..bdc3f7c 100644 --- a/utils/grammar/matcher.cc +++ b/utils/grammar/parsing/matcher.cc
@@ -13,7 +13,7 @@ // limitations under the License. // -#include "utils/grammar/matcher.h" +#include "utils/grammar/parsing/matcher.h" #include <iostream> #include <limits> @@ -132,7 +132,7 @@ } ++match_length; - // By the loop variant and due to the fact that the strings are sorted, + // By the loop invariant and due to the fact that the strings are sorted, // a matching string will be at `left` now. if (!input_iterator.HasNext()) { const int string_offset = LittleEndian::ToHost32(offsets[left]); @@ -219,7 +219,7 @@ } inline void GetLhs(const RulesSet* rules_set, const int lhs_entry, - Nonterm* nonterminal, CallbackId* callback, uint64* param, + Nonterm* nonterminal, CallbackId* callback, int64* param, int8* max_whitespace_gap) { if (lhs_entry > 0) { // Direct encoding of the nonterminal. @@ -238,27 +238,18 @@ } // namespace -void Matcher::Reset() { - state_ = STATE_DEFAULT; - arena_.Reset(); - pending_items_ = nullptr; - pending_exclusion_items_ = nullptr; - std::fill(chart_.begin(), chart_.end(), nullptr); - last_end_ = std::numeric_limits<int>().lowest(); -} - void Matcher::Finish() { // Check any pending items. ProcessPendingExclusionMatches(); } -void Matcher::QueueForProcessing(Match* item) { +void Matcher::QueueForProcessing(ParseTree* item) { // Push element to the front. item->next = pending_items_; pending_items_ = item; } -void Matcher::QueueForPostCheck(ExclusionMatch* item) { +void Matcher::QueueForPostCheck(ExclusionNode* item) { // Push element to the front. item->next = pending_exclusion_items_; pending_exclusion_items_ = item; @@ -284,11 +275,11 @@ ExecuteLhsSet( codepoint_span, match_offset, /*whitespace_gap=*/(codepoint_span.first - match_offset), - [terminal](Match* match) { - match->terminal = terminal.data(); - match->rhs2 = nullptr; + [terminal](ParseTree* parse_tree) { + parse_tree->terminal = terminal.data(); + parse_tree->rhs2 = nullptr; }, - lhs_set, delegate_); + lhs_set); } // Try case-insensitive matches. 
@@ -300,42 +291,41 @@ ExecuteLhsSet( codepoint_span, match_offset, /*whitespace_gap=*/(codepoint_span.first - match_offset), - [terminal](Match* match) { - match->terminal = terminal.data(); - match->rhs2 = nullptr; + [terminal](ParseTree* parse_tree) { + parse_tree->terminal = terminal.data(); + parse_tree->rhs2 = nullptr; }, - lhs_set, delegate_); + lhs_set); } } ProcessPendingSet(); } -void Matcher::AddMatch(Match* match) { - TC3_CHECK_GE(match->codepoint_span.second, last_end_); +void Matcher::AddParseTree(ParseTree* parse_tree) { + TC3_CHECK_GE(parse_tree->codepoint_span.second, last_end_); // Finish any pending post-checks. - if (match->codepoint_span.second > last_end_) { + if (parse_tree->codepoint_span.second > last_end_) { ProcessPendingExclusionMatches(); } - last_end_ = match->codepoint_span.second; - QueueForProcessing(match); + last_end_ = parse_tree->codepoint_span.second; + QueueForProcessing(parse_tree); ProcessPendingSet(); } -void Matcher::ExecuteLhsSet(const CodepointSpan codepoint_span, - const int match_offset_bytes, - const int whitespace_gap, - const std::function<void(Match*)>& initializer, - const RulesSet_::LhsSet* lhs_set, - CallbackDelegate* delegate) { +void Matcher::ExecuteLhsSet( + const CodepointSpan codepoint_span, const int match_offset_bytes, + const int whitespace_gap, + const std::function<void(ParseTree*)>& initializer_fn, + const RulesSet_::LhsSet* lhs_set) { TC3_CHECK(lhs_set); - Match* match = nullptr; + ParseTree* parse_tree = nullptr; Nonterm prev_lhs = kUnassignedNonterm; for (const int32 lhs_entry : *lhs_set->lhs()) { Nonterm lhs; CallbackId callback_id; - uint64 callback_param; + int64 callback_param; int8 max_whitespace_gap; GetLhs(rules_, lhs_entry, &lhs, &callback_id, &callback_param, &max_whitespace_gap); @@ -345,91 +335,70 @@ continue; } - // Handle default callbacks. + // Handle callbacks. 
switch (static_cast<DefaultCallback>(callback_id)) { - case DefaultCallback::kSetType: { - Match* typed_match = AllocateAndInitMatch<Match>(lhs, codepoint_span, - match_offset_bytes); - initializer(typed_match); - typed_match->type = callback_param; - QueueForProcessing(typed_match); - continue; - } case DefaultCallback::kAssertion: { - AssertionMatch* assertion_match = AllocateAndInitMatch<AssertionMatch>( - lhs, codepoint_span, match_offset_bytes); - initializer(assertion_match); - assertion_match->type = Match::kAssertionMatch; - assertion_match->negative = (callback_param != 0); - QueueForProcessing(assertion_match); + AssertionNode* assertion_node = arena_->AllocAndInit<AssertionNode>( + lhs, codepoint_span, match_offset_bytes, + /*negative=*/(callback_param != 0)); + initializer_fn(assertion_node); + QueueForProcessing(assertion_node); continue; } case DefaultCallback::kMapping: { - MappingMatch* mapping_match = AllocateAndInitMatch<MappingMatch>( - lhs, codepoint_span, match_offset_bytes); - initializer(mapping_match); - mapping_match->type = Match::kMappingMatch; - mapping_match->id = callback_param; - QueueForProcessing(mapping_match); + MappingNode* mapping_node = arena_->AllocAndInit<MappingNode>( + lhs, codepoint_span, match_offset_bytes, /*id=*/callback_param); + initializer_fn(mapping_node); + QueueForProcessing(mapping_node); continue; } case DefaultCallback::kExclusion: { // We can only check the exclusion once all matches up to this position // have been processed. Schedule and post check later. 
- ExclusionMatch* exclusion_match = AllocateAndInitMatch<ExclusionMatch>( - lhs, codepoint_span, match_offset_bytes); - initializer(exclusion_match); - exclusion_match->exclusion_nonterm = callback_param; - QueueForPostCheck(exclusion_match); + ExclusionNode* exclusion_node = arena_->AllocAndInit<ExclusionNode>( + lhs, codepoint_span, match_offset_bytes, + /*exclusion_nonterm=*/callback_param); + initializer_fn(exclusion_node); + QueueForPostCheck(exclusion_node); + continue; + } + case DefaultCallback::kSemanticExpression: { + SemanticExpressionNode* expression_node = + arena_->AllocAndInit<SemanticExpressionNode>( + lhs, codepoint_span, match_offset_bytes, + /*expression=*/ + rules_->semantic_expression()->Get(callback_param)); + initializer_fn(expression_node); + QueueForProcessing(expression_node); continue; } default: break; } - if (callback_id != kNoCallback && rules_->callback() != nullptr) { - const RulesSet_::CallbackEntry* callback_info = - rules_->callback()->LookupByKey(callback_id); - if (callback_info && callback_info->value().is_filter()) { - // Filter callback. - Match candidate; - candidate.Init(lhs, codepoint_span, match_offset_bytes); - initializer(&candidate); - delegate->MatchFound(&candidate, callback_id, callback_param, this); - continue; - } - } - if (prev_lhs != lhs) { prev_lhs = lhs; - match = - AllocateAndInitMatch<Match>(lhs, codepoint_span, match_offset_bytes); - initializer(match); - QueueForProcessing(match); + parse_tree = arena_->AllocAndInit<ParseTree>( + lhs, codepoint_span, match_offset_bytes, ParseTree::Type::kDefault); + initializer_fn(parse_tree); + QueueForProcessing(parse_tree); } - if (callback_id != kNoCallback) { - // This is an output callback. 
- delegate->MatchFound(match, callback_id, callback_param, this); + if (static_cast<DefaultCallback>(callback_id) == + DefaultCallback::kRootRule) { + chart_.AddDerivation(Derivation{parse_tree, /*rule_id=*/callback_param}); } } } void Matcher::ProcessPendingSet() { - // Avoid recursion caused by: - // ProcessPendingSet --> callback --> AddMatch --> ProcessPendingSet --> ... - if (state_ == STATE_PROCESSING) { - return; - } - state_ = STATE_PROCESSING; while (pending_items_) { // Process. - Match* item = pending_items_; + ParseTree* item = pending_items_; pending_items_ = pending_items_->next; // Add it to the chart. - item->next = chart_[item->codepoint_span.second & kChartHashTableBitmask]; - chart_[item->codepoint_span.second & kChartHashTableBitmask] = item; + chart_.Add(item); // Check unary rules that trigger. for (const RulesSet_::Rules* shard : rules_shards_) { @@ -439,26 +408,19 @@ item->codepoint_span, item->match_offset, /*whitespace_gap=*/ (item->codepoint_span.first - item->match_offset), - [item](Match* match) { - match->rhs1 = nullptr; - match->rhs2 = item; + [item](ParseTree* parse_tree) { + parse_tree->rhs1 = nullptr; + parse_tree->rhs2 = item; }, - lhs_set, delegate_); + lhs_set); } } // Check binary rules that trigger. // Lookup by begin. - Match* prev = chart_[item->match_offset & kChartHashTableBitmask]; - // The chain of items is in decreasing `end` order. - // Find the ones that have prev->end == item->begin. 
- while (prev != nullptr && - (prev->codepoint_span.second > item->match_offset)) { - prev = prev->next; - } - for (; - prev != nullptr && (prev->codepoint_span.second == item->match_offset); - prev = prev->next) { + for (Chart<>::Iterator it = chart_.MatchesEndingAt(item->match_offset); + !it.Done(); it.Next()) { + const ParseTree* prev = it.Item(); for (const RulesSet_::Rules* shard : rules_shards_) { if (const RulesSet_::LhsSet* lhs_set = FindBinaryRulesMatches(rules_, shard, {prev->lhs, item->lhs})) { @@ -470,45 +432,27 @@ (item->codepoint_span.first - item->match_offset), // Whitespace gap is the gap // between the two parts. - [prev, item](Match* match) { - match->rhs1 = prev; - match->rhs2 = item; + [prev, item](ParseTree* parse_tree) { + parse_tree->rhs1 = prev; + parse_tree->rhs2 = item; }, - lhs_set, delegate_); + lhs_set); } } } } - state_ = STATE_DEFAULT; } void Matcher::ProcessPendingExclusionMatches() { while (pending_exclusion_items_) { - ExclusionMatch* item = pending_exclusion_items_; - pending_exclusion_items_ = static_cast<ExclusionMatch*>(item->next); + ExclusionNode* item = pending_exclusion_items_; + pending_exclusion_items_ = static_cast<ExclusionNode*>(item->next); // Check that the exclusion condition is fulfilled. - if (!ContainsMatch(item->exclusion_nonterm, item->codepoint_span)) { - AddMatch(item); + if (!chart_.HasMatch(item->exclusion_nonterm, item->codepoint_span)) { + AddParseTree(item); } } } -bool Matcher::ContainsMatch(const Nonterm nonterm, - const CodepointSpan& span) const { - // Lookup by end. - Match* match = chart_[span.second & kChartHashTableBitmask]; - // The chain of items is in decreasing `end` order. 
- while (match != nullptr && match->codepoint_span.second > span.second) { - match = match->next; - } - while (match != nullptr && match->codepoint_span.second == span.second) { - if (match->lhs == nonterm && match->codepoint_span.first == span.first) { - return true; - } - match = match->next; - } - return false; -} - } // namespace libtextclassifier3::grammar diff --git a/utils/grammar/parsing/matcher.h b/utils/grammar/parsing/matcher.h new file mode 100644 index 0000000..5ee2bcc --- /dev/null +++ b/utils/grammar/parsing/matcher.h
@@ -0,0 +1,150 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+// A token based context-free grammar matcher.
+//
+// A parser passes tokens to the matcher: literal terminal strings and token
+// types.
+// The parser passes each token along with the [begin, end) position range
+// in which it occurs. So for an input string "Groundhog February 2, 2007", the
+// parser would tell the matcher that:
+//
+//     "Groundhog" occurs at [0, 9)
+//     "February" occurs at [9, 18)
+//     <digits> occurs at [18, 20)
+//     "," occurs at [20, 21)
+//     <digits> occurs at [21, 26)
+//
+// Multiple overlapping symbols can be passed.
+// The only constraint on symbol order is that they have to be passed in
+// left-to-right order, strictly speaking, their "end" positions must be
+// nondecreasing. This constraint allows a more efficient matching algorithm.
+// The "begin" positions can be in any order.
+ +#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_MATCHER_H_ +#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_MATCHER_H_ + +#include <array> +#include <functional> +#include <vector> + +#include "annotator/types.h" +#include "utils/base/arena.h" +#include "utils/grammar/parsing/chart.h" +#include "utils/grammar/parsing/derivation.h" +#include "utils/grammar/parsing/parse-tree.h" +#include "utils/grammar/rules_generated.h" +#include "utils/strings/stringpiece.h" +#include "utils/utf8/unilib.h" + +namespace libtextclassifier3::grammar { + +class Matcher { + public: + explicit Matcher(const UniLib* unilib, const RulesSet* rules, + const std::vector<const RulesSet_::Rules*> rules_shards, + UnsafeArena* arena) + : unilib_(*unilib), + arena_(arena), + last_end_(std::numeric_limits<int>().lowest()), + rules_(rules), + rules_shards_(rules_shards), + pending_items_(nullptr), + pending_exclusion_items_(nullptr) { + TC3_CHECK_NE(rules, nullptr); + } + + explicit Matcher(const UniLib* unilib, const RulesSet* rules, + UnsafeArena* arena) + : Matcher(unilib, rules, {}, arena) { + rules_shards_.reserve(rules->rules()->size()); + rules_shards_.insert(rules_shards_.end(), rules->rules()->begin(), + rules->rules()->end()); + } + + // Finish the matching. + void Finish(); + + // Tells the matcher that the given terminal was found occupying position + // range [begin, end) in the input. + // The matcher may invoke callback functions before returning, if this + // terminal triggers any new matches for rules in the grammar. + // Calls to AddTerminal() and AddParseTree() must be in left-to-right order, + // that is, the sequence of `end` values must be non-decreasing. + void AddTerminal(const CodepointSpan codepoint_span, const int match_offset, + StringPiece terminal); + void AddTerminal(const CodepointIndex begin, const CodepointIndex end, + StringPiece terminal) { + AddTerminal(CodepointSpan{begin, end}, begin, terminal); + } + + // Adds predefined parse tree. 
+ void AddParseTree(ParseTree* parse_tree); + + const Chart<> chart() const { return chart_; } + + private: + // Process matches from lhs set. + void ExecuteLhsSet(const CodepointSpan codepoint_span, const int match_offset, + const int whitespace_gap, + const std::function<void(ParseTree*)>& initializer_fn, + const RulesSet_::LhsSet* lhs_set); + + // Queues a newly created match item. + void QueueForProcessing(ParseTree* item); + + // Queues a match item for later post checking of the exclusion condition. + // For exclusions we need to check that the `item->excluded_nonterminal` + // doesn't match the same span. As we cannot know which matches have already + // been added, we queue the item for later post checking - once all matches + // up to `item->codepoint_span.second` have been added. + void QueueForPostCheck(ExclusionNode* item); + + // Adds pending items to the chart, possibly generating new matches as a + // result. + void ProcessPendingSet(); + + // Checks all pending exclusion matches that their exclusion condition is + // fulfilled. + void ProcessPendingExclusionMatches(); + + UniLib unilib_; + + // Memory arena for match allocation. + UnsafeArena* arena_; + + // The end position of the most recent match or terminal, for sanity + // checking. + int last_end_; + + // Rules. + const RulesSet* rules_; + // The active rule shards. + std::vector<const RulesSet_::Rules*> rules_shards_; + + // The set of items pending to be added to the chart as a singly-linked list. + ParseTree* pending_items_; + + // The set of items pending to be post-checked as a singly-linked list. + ExclusionNode* pending_exclusion_items_; + + // The chart data structure: a hashtable containing all matches, indexed by + // their end positions. 
+ Chart<> chart_; +}; + +} // namespace libtextclassifier3::grammar + +#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_MATCHER_H_ diff --git a/utils/grammar/parsing/parse-tree.cc b/utils/grammar/parsing/parse-tree.cc new file mode 100644 index 0000000..8f69394 --- /dev/null +++ b/utils/grammar/parsing/parse-tree.cc
@@ -0,0 +1,54 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "utils/grammar/parsing/parse-tree.h" + +#include <algorithm> +#include <stack> + +namespace libtextclassifier3::grammar { + +void Traverse(const ParseTree* root, + const std::function<bool(const ParseTree*)>& node_fn) { + std::stack<const ParseTree*> open; + open.push(root); + + while (!open.empty()) { + const ParseTree* node = open.top(); + open.pop(); + if (!node_fn(node) || node->IsLeaf()) { + continue; + } + open.push(node->rhs2); + if (node->rhs1 != nullptr) { + open.push(node->rhs1); + } + } +} + +std::vector<const ParseTree*> SelectAll( + const ParseTree* root, + const std::function<bool(const ParseTree*)>& pred_fn) { + std::vector<const ParseTree*> result; + Traverse(root, [&result, pred_fn](const ParseTree* node) { + if (pred_fn(node)) { + result.push_back(node); + } + return true; + }); + return result; +} + +} // namespace libtextclassifier3::grammar diff --git a/utils/grammar/parsing/parse-tree.h b/utils/grammar/parsing/parse-tree.h new file mode 100644 index 0000000..0648530 --- /dev/null +++ b/utils/grammar/parsing/parse-tree.h
@@ -0,0 +1,194 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_PARSE_TREE_H_ +#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_PARSE_TREE_H_ + +#include <functional> +#include <vector> + +#include "annotator/types.h" +#include "utils/grammar/semantics/expression_generated.h" +#include "utils/grammar/types.h" +#include "utils/strings/stringpiece.h" + +namespace libtextclassifier3::grammar { + +// Represents a parse tree for a match that was found for a nonterminal. +struct ParseTree { + enum class Type : int8 { + // Default, untyped match. + kDefault = 0, + + // An assertion match (see: AssertionNode). + kAssertion = 1, + + // A value mapping match (see: MappingNode). + kMapping = 2, + + // An exclusion match (see: ExclusionNode). + kExclusion = 3, + + // A match for an annotation (see: AnnotationNode). + kAnnotation = 4, + + // A match for a semantic annotation (see: SemanticExpressionNode). 
+ kExpression = 5, + }; + + explicit ParseTree() = default; + explicit ParseTree(const Nonterm lhs, const CodepointSpan& codepoint_span, + const int match_offset, const Type type) + : lhs(lhs), + type(type), + codepoint_span(codepoint_span), + match_offset(match_offset) {} + + // For binary rule matches: rhs1 != NULL and rhs2 != NULL + // unary rule matches: rhs1 == NULL and rhs2 != NULL + // terminal rule matches: rhs1 != NULL and rhs2 == NULL + // custom leaves: rhs1 == NULL and rhs2 == NULL + bool IsInteriorNode() const { return rhs2 != nullptr; } + bool IsLeaf() const { return !rhs2; } + + bool IsBinaryRule() const { return rhs1 && rhs2; } + bool IsUnaryRule() const { return !rhs1 && rhs2; } + bool IsTerminalRule() const { return rhs1 && !rhs2; } + bool HasLeadingWhitespace() const { + return codepoint_span.first != match_offset; + } + + const ParseTree* unary_rule_rhs() const { return rhs2; } + + // Used in singly-linked queue of matches for processing. + ParseTree* next = nullptr; + + // Nonterminal we found a match for. + Nonterm lhs = kUnassignedNonterm; + + // Type of the match. + Type type = Type::kDefault; + + // The span in codepoints. + CodepointSpan codepoint_span; + + // The begin codepoint offset used during matching. + // This is usually including any prefix whitespace. + int match_offset; + + union { + // The first sub match for binary rules. + const ParseTree* rhs1 = nullptr; + + // The terminal, for terminal rules. + const char* terminal; + }; + // First or second sub-match for interior nodes. + const ParseTree* rhs2 = nullptr; +}; + +// Node type to keep track of associated values. +struct MappingNode : public ParseTree { + explicit MappingNode(const Nonterm arg_lhs, + const CodepointSpan arg_codepoint_span, + const int arg_match_offset, const int64 arg_value) + : ParseTree(arg_lhs, arg_codepoint_span, arg_match_offset, + Type::kMapping), + id(arg_value) {} + // The associated id or value. 
+ int64 id; +}; + +// Node type to keep track of assertions. +struct AssertionNode : public ParseTree { + explicit AssertionNode(const Nonterm arg_lhs, + const CodepointSpan arg_codepoint_span, + const int arg_match_offset, const bool arg_negative) + : ParseTree(arg_lhs, arg_codepoint_span, arg_match_offset, + Type::kAssertion), + negative(arg_negative) {} + // If true, the assertion is negative and will be valid if the input doesn't + // match. + bool negative; +}; + +// Node type to define exclusions. +struct ExclusionNode : public ParseTree { + explicit ExclusionNode(const Nonterm arg_lhs, + const CodepointSpan arg_codepoint_span, + const int arg_match_offset, + const Nonterm arg_exclusion_nonterm) + : ParseTree(arg_lhs, arg_codepoint_span, arg_match_offset, + Type::kExclusion), + exclusion_nonterm(arg_exclusion_nonterm) {} + // The nonterminal that denotes matches to exclude from a successful match. + // So the match is only valid if there is no match of `exclusion_nonterm` + // spanning the same text range. + Nonterm exclusion_nonterm; +}; + +// Match to represent an annotator annotated span in the grammar. +struct AnnotationNode : public ParseTree { + explicit AnnotationNode(const Nonterm arg_lhs, + const CodepointSpan arg_codepoint_span, + const int arg_match_offset, + const ClassificationResult* arg_annotation) + : ParseTree(arg_lhs, arg_codepoint_span, arg_match_offset, + Type::kAnnotation), + annotation(arg_annotation) {} + const ClassificationResult* annotation; +}; + +// Node type to represent an associated semantic expression. +struct SemanticExpressionNode : public ParseTree { + explicit SemanticExpressionNode(const Nonterm arg_lhs, + const CodepointSpan arg_codepoint_span, + const int arg_match_offset, + const SemanticExpression* arg_expression) + : ParseTree(arg_lhs, arg_codepoint_span, arg_match_offset, + Type::kExpression), + expression(arg_expression) {} + const SemanticExpression* expression; +}; + +// Utility functions for parse tree traversal. 
+ +// Does a preorder traversal, calling `node_fn` on each node. +// `node_fn` is expected to return whether to continue expanding a node. +void Traverse(const ParseTree* root, + const std::function<bool(const ParseTree*)>& node_fn); + +// Does a preorder traversal, selecting all nodes where `pred_fn` returns true. +std::vector<const ParseTree*> SelectAll( + const ParseTree* root, + const std::function<bool(const ParseTree*)>& pred_fn); + +// Retrieves all nodes of a given type. +template <typename T> +const std::vector<const T*> SelectAllOfType(const ParseTree* root, + const ParseTree::Type type) { + std::vector<const T*> result; + Traverse(root, [&result, type](const ParseTree* node) { + if (node->type == type) { + result.push_back(static_cast<const T*>(node)); + } + return true; + }); + return result; +} + +} // namespace libtextclassifier3::grammar + +#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_PARSE_TREE_H_ diff --git a/utils/grammar/parsing/parser.cc b/utils/grammar/parsing/parser.cc new file mode 100644 index 0000000..5efca93 --- /dev/null +++ b/utils/grammar/parsing/parser.cc
@@ -0,0 +1,277 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "utils/grammar/parsing/parser.h" + +#include <unordered_map> + +#include "utils/grammar/parsing/parse-tree.h" +#include "utils/grammar/rules-utils.h" +#include "utils/grammar/types.h" +#include "utils/zlib/tclib_zlib.h" +#include "utils/zlib/zlib_regex.h" + +namespace libtextclassifier3::grammar { +namespace { + +inline bool CheckMemoryUsage(const UnsafeArena* arena) { + // The maximum memory usage for matching. + constexpr int kMaxMemoryUsage = 1 << 20; + return arena->status().bytes_allocated() <= kMaxMemoryUsage; +} + +// Maps a codepoint to include the token padding if it aligns with a token +// start. Whitespace is ignored when symbols are fed to the matcher. Preceding +// whitespace is merged to the match start so that tokens and non-terminals +// appear next to each other without whitespace. For text or regex annotations, +// we therefore merge the whitespace padding to the start if the annotation +// starts at a token. 
+int MapCodepointToTokenPaddingIfPresent( + const std::unordered_map<CodepointIndex, CodepointIndex>& token_alignment, + const int start) { + const auto it = token_alignment.find(start); + if (it != token_alignment.end()) { + return it->second; + } + return start; +} + +} // namespace + +Parser::Parser(const UniLib* unilib, const RulesSet* rules) + : unilib_(*unilib), + rules_(rules), + lexer_(unilib), + nonterminals_(rules_->nonterminals()), + rules_locales_(ParseRulesLocales(rules_)), + regex_annotators_(BuildRegexAnnotators()) {} + +// Uncompresses and builds the defined regex annotators. +std::vector<Parser::RegexAnnotator> Parser::BuildRegexAnnotators() const { + std::vector<RegexAnnotator> result; + if (rules_->regex_annotator() != nullptr) { + std::unique_ptr<ZlibDecompressor> decompressor = + ZlibDecompressor::Instance(); + result.reserve(rules_->regex_annotator()->size()); + for (const RulesSet_::RegexAnnotator* regex_annotator : + *rules_->regex_annotator()) { + result.push_back( + {UncompressMakeRegexPattern(unilib_, regex_annotator->pattern(), + regex_annotator->compressed_pattern(), + rules_->lazy_regex_compilation(), + decompressor.get()), + regex_annotator->nonterminal()}); + } + } + return result; +} + +std::vector<Symbol> Parser::SortedSymbolsForInput(const TextContext& input, + UnsafeArena* arena) const { + // Whitespace is ignored when symbols are fed to the matcher. + // For regex matches and existing text annotations we therefore have to merge + // preceding whitespace to the match start so that tokens and non-terminals + // appear next to each other without whitespace. We keep track of real + // token starts and preceding whitespace in `token_match_start`, so that we + // can extend a match's start to include the preceding whitespace.
+ std::unordered_map<CodepointIndex, CodepointIndex> token_match_start; + for (int i = input.context_span.first + 1; i < input.context_span.second; + i++) { + const CodepointIndex token_start = input.tokens[i].start; + const CodepointIndex prev_token_end = input.tokens[i - 1].end; + if (token_start != prev_token_end) { + token_match_start[token_start] = prev_token_end; + } + } + + std::vector<Symbol> symbols; + CodepointIndex match_offset = input.tokens[input.context_span.first].start; + + // Add start symbol. + if (input.context_span.first == 0 && + nonterminals_->start_nt() != kUnassignedNonterm) { + match_offset = 0; + symbols.emplace_back(arena->AllocAndInit<ParseTree>( + nonterminals_->start_nt(), CodepointSpan{0, 0}, + /*match_offset=*/0, ParseTree::Type::kDefault)); + } + + if (nonterminals_->wordbreak_nt() != kUnassignedNonterm) { + symbols.emplace_back(arena->AllocAndInit<ParseTree>( + nonterminals_->wordbreak_nt(), + CodepointSpan{match_offset, match_offset}, + /*match_offset=*/match_offset, ParseTree::Type::kDefault)); + } + + // Add symbols from tokens. + for (int i = input.context_span.first; i < input.context_span.second; i++) { + const Token& token = input.tokens[i]; + lexer_.AppendTokenSymbols(token.value, /*match_offset=*/match_offset, + CodepointSpan{token.start, token.end}, &symbols); + match_offset = token.end; + + // Add word break symbol. + if (nonterminals_->wordbreak_nt() != kUnassignedNonterm) { + symbols.emplace_back(arena->AllocAndInit<ParseTree>( + nonterminals_->wordbreak_nt(), + CodepointSpan{match_offset, match_offset}, + /*match_offset=*/match_offset, ParseTree::Type::kDefault)); + } + } + + // Add end symbol if used by the grammar. 
+ if (input.context_span.second == input.tokens.size() && + nonterminals_->end_nt() != kUnassignedNonterm) { + symbols.emplace_back(arena->AllocAndInit<ParseTree>( + nonterminals_->end_nt(), CodepointSpan{match_offset, match_offset}, + /*match_offset=*/match_offset, ParseTree::Type::kDefault)); + } + + // Add symbols from the regex annotators. + const CodepointIndex context_start = + input.tokens[input.context_span.first].start; + const CodepointIndex context_end = + input.tokens[input.context_span.second - 1].end; + for (const RegexAnnotator& regex_annotator : regex_annotators_) { + std::unique_ptr<UniLib::RegexMatcher> regex_matcher = + regex_annotator.pattern->Matcher(UnicodeText::Substring( + input.text, context_start, context_end, /*do_copy=*/false)); + int status = UniLib::RegexMatcher::kNoError; + while (regex_matcher->Find(&status) && + status == UniLib::RegexMatcher::kNoError) { + const CodepointSpan span{regex_matcher->Start(0, &status) + context_start, + regex_matcher->End(0, &status) + context_start}; + symbols.emplace_back(arena->AllocAndInit<ParseTree>( + regex_annotator.nonterm, span, /*match_offset=*/ + MapCodepointToTokenPaddingIfPresent(token_match_start, span.first), + ParseTree::Type::kDefault)); + } + } + + // Add symbols based on annotations. 
+ if (auto annotation_nonterminals = nonterminals_->annotation_nt()) { + for (const AnnotatedSpan& annotated_span : input.annotations) { + const ClassificationResult& classification = + annotated_span.classification.front(); + if (auto entry = annotation_nonterminals->LookupByKey( + classification.collection.c_str())) { + symbols.emplace_back(arena->AllocAndInit<AnnotationNode>( + entry->value(), annotated_span.span, /*match_offset=*/ + MapCodepointToTokenPaddingIfPresent(token_match_start, + annotated_span.span.first), + &classification)); + } + } + } + + std::sort(symbols.begin(), symbols.end(), + [](const Symbol& a, const Symbol& b) { + // Sort by increasing (end, start) position to guarantee the + // matcher requirement that the tokens are fed in non-decreasing + // end position order. + return std::tie(a.codepoint_span.second, a.codepoint_span.first) < + std::tie(b.codepoint_span.second, b.codepoint_span.first); + }); + + return symbols; +} + +void Parser::EmitSymbol(const Symbol& symbol, UnsafeArena* arena, + Matcher* matcher) const { + if (!CheckMemoryUsage(arena)) { + return; + } + switch (symbol.type) { + case Symbol::Type::TYPE_PARSE_TREE: { + // Just emit the parse tree. + matcher->AddParseTree(symbol.parse_tree); + return; + } + case Symbol::Type::TYPE_DIGITS: { + // Emit <digits> if used by the rules. + if (nonterminals_->digits_nt() != kUnassignedNonterm) { + matcher->AddParseTree(arena->AllocAndInit<ParseTree>( + nonterminals_->digits_nt(), symbol.codepoint_span, + symbol.match_offset, ParseTree::Type::kDefault)); + } + + // Emit <n_digits> if used by the rules. 
+ if (nonterminals_->n_digits_nt() != nullptr) { + const int num_digits = + symbol.codepoint_span.second - symbol.codepoint_span.first; + if (num_digits <= nonterminals_->n_digits_nt()->size()) { + const Nonterm n_digits_nt = + nonterminals_->n_digits_nt()->Get(num_digits - 1); + if (n_digits_nt != kUnassignedNonterm) { + matcher->AddParseTree(arena->AllocAndInit<ParseTree>( + nonterminals_->n_digits_nt()->Get(num_digits - 1), + symbol.codepoint_span, symbol.match_offset, + ParseTree::Type::kDefault)); + } + } + } + break; + } + case Symbol::Type::TYPE_TERM: { + // Emit <uppercase_token> if used by the rules. + if (nonterminals_->uppercase_token_nt() != 0 && + unilib_.IsUpperText( + UTF8ToUnicodeText(symbol.lexeme, /*do_copy=*/false))) { + matcher->AddParseTree(arena->AllocAndInit<ParseTree>( + nonterminals_->uppercase_token_nt(), symbol.codepoint_span, + symbol.match_offset, ParseTree::Type::kDefault)); + } + break; + } + default: + break; + } + + // Emit the token as terminal. + matcher->AddTerminal(symbol.codepoint_span, symbol.match_offset, + symbol.lexeme); + + // Emit <token> if used by rules. + matcher->AddParseTree(arena->AllocAndInit<ParseTree>( + nonterminals_->token_nt(), symbol.codepoint_span, symbol.match_offset, + ParseTree::Type::kDefault)); +} + +// Parses an input text and returns the root rule derivations. +std::vector<Derivation> Parser::Parse(const TextContext& input, + UnsafeArena* arena) const { + // Check the tokens, input can be non-empty (whitespace) but have no tokens. + if (input.tokens.empty()) { + return {}; + } + + // Select locale matching rules. + std::vector<const RulesSet_::Rules*> locale_rules = + SelectLocaleMatchingShards(rules_, rules_locales_, input.locales); + + if (locale_rules.empty()) { + // Nothing to do. 
+ return {}; + } + + Matcher matcher(&unilib_, rules_, locale_rules, arena); + for (const Symbol& symbol : SortedSymbolsForInput(input, arena)) { + EmitSymbol(symbol, arena, &matcher); + } + matcher.Finish(); + return matcher.chart().derivations(); +} + +} // namespace libtextclassifier3::grammar diff --git a/utils/grammar/parsing/parser.h b/utils/grammar/parsing/parser.h new file mode 100644 index 0000000..d96bfdc --- /dev/null +++ b/utils/grammar/parsing/parser.h
@@ -0,0 +1,81 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_PARSER_H_ +#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_PARSER_H_ + +#include <vector> + +#include "annotator/types.h" +#include "utils/base/arena.h" +#include "utils/grammar/parsing/derivation.h" +#include "utils/grammar/parsing/lexer.h" +#include "utils/grammar/parsing/matcher.h" +#include "utils/grammar/rules_generated.h" +#include "utils/grammar/text-context.h" +#include "utils/i18n/locale.h" +#include "utils/utf8/unilib.h" + +namespace libtextclassifier3::grammar { + +// Syntactic parsing pass. +// The parser validates and deduplicates candidates produced by the grammar +// matcher. It augments the parse trees with derivation information for semantic +// evaluation. +class Parser { + public: + explicit Parser(const UniLib* unilib, const RulesSet* rules); + + // Parses an input text and returns the root rule derivations. + std::vector<Derivation> Parse(const TextContext& input, + UnsafeArena* arena) const; + + private: + struct RegexAnnotator { + std::unique_ptr<UniLib::RegexPattern> pattern; + Nonterm nonterm; + }; + + // Uncompresses and builds the defined regex annotators. + std::vector<RegexAnnotator> BuildRegexAnnotators() const; + + // Produces symbols for a text input to feed to a matcher. + // These are symbols for each token from the lexer, existing text annotations + // and regex annotations.
+ // The symbols are sorted with increasing end-positions to satisfy the matcher + // requirements. + std::vector<Symbol> SortedSymbolsForInput(const TextContext& input, + UnsafeArena* arena) const; + + // Emits a symbol to the matcher. + void EmitSymbol(const Symbol& symbol, UnsafeArena* arena, + Matcher* matcher) const; + + const UniLib& unilib_; + const RulesSet* rules_; + const Lexer lexer_; + + // Pre-defined nonterminals. + const RulesSet_::Nonterminals* nonterminals_; + + // Pre-parsed locales of the rules. + const std::vector<std::vector<Locale>> rules_locales_; + + std::vector<RegexAnnotator> regex_annotators_; +}; + +} // namespace libtextclassifier3::grammar + +#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_PARSER_H_ diff --git a/utils/grammar/rules-utils.cc b/utils/grammar/rules-utils.cc index ab1c45c..44e1a1d 100644 --- a/utils/grammar/rules-utils.cc +++ b/utils/grammar/rules-utils.cc
@@ -53,70 +53,4 @@ return shards; } -std::vector<Derivation> DeduplicateDerivations( - const std::vector<Derivation>& derivations) { - std::vector<Derivation> sorted_candidates = derivations; - std::stable_sort( - sorted_candidates.begin(), sorted_candidates.end(), - [](const Derivation& a, const Derivation& b) { - // Sort by id. - if (a.rule_id != b.rule_id) { - return a.rule_id < b.rule_id; - } - - // Sort by increasing start. - if (a.match->codepoint_span.first != b.match->codepoint_span.first) { - return a.match->codepoint_span.first < b.match->codepoint_span.first; - } - - // Sort by decreasing end. - return a.match->codepoint_span.second > b.match->codepoint_span.second; - }); - - // Deduplicate by overlap. - std::vector<Derivation> result; - for (int i = 0; i < sorted_candidates.size(); i++) { - const Derivation& candidate = sorted_candidates[i]; - bool eliminated = false; - - // Due to the sorting above, the candidate can only be completely - // intersected by a match before it in the sorted order. - for (int j = i - 1; j >= 0; j--) { - if (sorted_candidates[j].rule_id != candidate.rule_id) { - break; - } - if (sorted_candidates[j].match->codepoint_span.first <= - candidate.match->codepoint_span.first && - sorted_candidates[j].match->codepoint_span.second >= - candidate.match->codepoint_span.second) { - eliminated = true; - break; - } - } - - if (!eliminated) { - result.push_back(candidate); - } - } - return result; -} - -bool VerifyAssertions(const Match* match) { - bool result = true; - grammar::Traverse(match, [&result](const Match* node) { - if (node->type != Match::kAssertionMatch) { - // Only validation if all checks so far passed. - return result; - } - - // Positive assertions are by definition fulfilled, - // fail if the assertion is negative. 
- if (static_cast<const AssertionMatch*>(node)->negative) { - result = false; - } - return result; - }); - return result; -} - } // namespace libtextclassifier3::grammar diff --git a/utils/grammar/rules-utils.h b/utils/grammar/rules-utils.h index 8664e95..68a6ae0 100644 --- a/utils/grammar/rules-utils.h +++ b/utils/grammar/rules-utils.h
@@ -13,17 +13,13 @@ // limitations under the License. // -#pragma GCC diagnostic ignored "-Wc++17-extensions" - // Auxiliary methods for using rules. #ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_RULES_UTILS_H_ #define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_RULES_UTILS_H_ -#include <unordered_map> #include <vector> -#include "utils/grammar/match.h" #include "utils/grammar/rules_generated.h" #include "utils/i18n/locale.h" @@ -38,22 +34,6 @@ const std::vector<std::vector<Locale>>& shard_locales, const std::vector<Locale>& locales); -// Deduplicates rule derivations by containing overlap. -// The grammar system can output multiple candidates for optional parts. -// For example if a rule has an optional suffix, we -// will get two rule derivations when the suffix is present: one with and one -// without the suffix. We therefore deduplicate by containing overlap, viz. from -// two candidates we keep the longer one if it completely contains the shorter. -struct Derivation { - const Match* match; - int64 rule_id; -}; -std::vector<Derivation> DeduplicateDerivations( - const std::vector<Derivation>& derivations); - -// Checks that all assertions of a match tree are fulfilled. -bool VerifyAssertions(const Match* match); - } // namespace libtextclassifier3::grammar #endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_RULES_UTILS_H_ diff --git a/utils/grammar/rules.fbs b/utils/grammar/rules.fbs index 2a8055e..b85bb3c 100755 --- a/utils/grammar/rules.fbs +++ b/utils/grammar/rules.fbs
@@ -13,7 +13,7 @@ // limitations under the License. // -include "utils/grammar/next/semantics/expression.fbs"; +include "utils/grammar/semantics/expression.fbs"; include "utils/zlib/buffer.fbs"; include "utils/i18n/language-tag.fbs"; @@ -147,19 +147,6 @@ annotation_nt:[Nonterminals_.AnnotationNtEntry]; } -// Callback information. -namespace libtextclassifier3.grammar.RulesSet_; -struct Callback { - // Whether the callback is a filter. - is_filter:bool; -} - -namespace libtextclassifier3.grammar.RulesSet_; -struct CallbackEntry { - key:uint (key); - value:Callback; -} - namespace libtextclassifier3.grammar.RulesSet_.DebugInformation_; table NonterminalNamesEntry { key:int (key); @@ -205,7 +192,7 @@ terminals:string; nonterminals:RulesSet_.Nonterminals; - callback:[RulesSet_.CallbackEntry]; + reserved_6:int16 (deprecated); debug_information:RulesSet_.DebugInformation; regex_annotator:[RulesSet_.RegexAnnotator]; @@ -213,7 +200,7 @@ lazy_regex_compilation:bool; // The semantic expressions associated with rule matches. - semantic_expression:[next.SemanticExpression]; + semantic_expression:[SemanticExpression]; // The schema defining the semantic results. semantic_values_schema:[ubyte]; diff --git a/utils/grammar/semantics/composer.cc b/utils/grammar/semantics/composer.cc new file mode 100644 index 0000000..fcf8263 --- /dev/null +++ b/utils/grammar/semantics/composer.cc
@@ -0,0 +1,131 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "utils/grammar/semantics/composer.h" + +#include "utils/base/status_macros.h" +#include "utils/grammar/semantics/evaluators/arithmetic-eval.h" +#include "utils/grammar/semantics/evaluators/compose-eval.h" +#include "utils/grammar/semantics/evaluators/const-eval.h" +#include "utils/grammar/semantics/evaluators/constituent-eval.h" +#include "utils/grammar/semantics/evaluators/merge-values-eval.h" +#include "utils/grammar/semantics/evaluators/parse-number-eval.h" +#include "utils/grammar/semantics/evaluators/span-eval.h" + +namespace libtextclassifier3::grammar { +namespace { + +// Gathers all constituents of a rule and indexes them. +// The constituents are numbered in the rule construction. But constituents could +// be in optional parts of the rule and might not be present in a match. +// This finds all constituents that are present in a match and allows +// retrieving them by their index. +std::unordered_map<int, const ParseTree*> GatherConstituents( + const ParseTree* root) { + std::unordered_map<int, const ParseTree*> constituents; + Traverse(root, [root, &constituents](const ParseTree* node) { + switch (node->type) { + case ParseTree::Type::kMapping: + TC3_CHECK(node->IsUnaryRule()); + constituents[static_cast<const MappingNode*>(node)->id] = + node->unary_rule_rhs(); + return false; + case ParseTree::Type::kDefault: + // Continue traversal.
+ return true; + default: + // Don't continue the traversal if we are not at the root node. + // This could e.g. be an assertion node. + return (node == root); + } + }); + return constituents; +} + +} // namespace + +SemanticComposer::SemanticComposer( + const reflection::Schema* semantic_values_schema) { + evaluators_.emplace(SemanticExpression_::Expression_ArithmeticExpression, + std::make_unique<ArithmeticExpressionEvaluator>(this)); + evaluators_.emplace(SemanticExpression_::Expression_ConstituentExpression, + std::make_unique<ConstituentEvaluator>()); + evaluators_.emplace(SemanticExpression_::Expression_ParseNumberExpression, + std::make_unique<ParseNumberEvaluator>(this)); + evaluators_.emplace(SemanticExpression_::Expression_SpanAsStringExpression, + std::make_unique<SpanAsStringEvaluator>()); + if (semantic_values_schema != nullptr) { + // Register semantic functions. + evaluators_.emplace( + SemanticExpression_::Expression_ComposeExpression, + std::make_unique<ComposeEvaluator>(this, semantic_values_schema)); + evaluators_.emplace( + SemanticExpression_::Expression_ConstValueExpression, + std::make_unique<ConstEvaluator>(semantic_values_schema)); + evaluators_.emplace( + SemanticExpression_::Expression_MergeValueExpression, + std::make_unique<MergeValuesEvaluator>(this, semantic_values_schema)); + } +} + +StatusOr<const SemanticValue*> SemanticComposer::Eval( + const TextContext& text_context, const Derivation& derivation, + UnsafeArena* arena) const { + if (!derivation.parse_tree->IsUnaryRule() || + derivation.parse_tree->unary_rule_rhs()->type != + ParseTree::Type::kExpression) { + return nullptr; + } + return Eval(text_context, + static_cast<const SemanticExpressionNode*>( + derivation.parse_tree->unary_rule_rhs()), + arena); +} + +StatusOr<const SemanticValue*> SemanticComposer::Eval( + const TextContext& text_context, const SemanticExpressionNode* derivation, + UnsafeArena* arena) const { + // Evaluate constituents. 
+ EvalContext context{&text_context, derivation}; + for (const auto& [constituent_index, constituent] : + GatherConstituents(derivation)) { + if (constituent->type == ParseTree::Type::kExpression) { + TC3_ASSIGN_OR_RETURN( + context.rule_constituents[constituent_index], + Eval(text_context, + static_cast<const SemanticExpressionNode*>(constituent), arena)); + } else { + // Just use the text of the constituent if no semantic expression was + // defined. + context.rule_constituents[constituent_index] = SemanticValue::Create( + text_context.Span(constituent->codepoint_span), arena); + } + } + return Apply(context, derivation->expression, arena); +} + +StatusOr<const SemanticValue*> SemanticComposer::Apply( + const EvalContext& context, const SemanticExpression* expression, + UnsafeArena* arena) const { + const auto handler_it = evaluators_.find(expression->expression_type()); + if (handler_it == evaluators_.end()) { + return Status(StatusCode::INVALID_ARGUMENT, + std::string("Unhandled expression type: ") + + EnumNameExpression(expression->expression_type())); + } + return handler_it->second->Apply(context, expression, arena); +} + +} // namespace libtextclassifier3::grammar diff --git a/utils/grammar/semantics/composer.h b/utils/grammar/semantics/composer.h new file mode 100644 index 0000000..2402085 --- /dev/null +++ b/utils/grammar/semantics/composer.h
@@ -0,0 +1,73 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_COMPOSER_H_ +#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_COMPOSER_H_ + +#include <unordered_map> +#include <vector> + +#include "utils/base/arena.h" +#include "utils/base/status.h" +#include "utils/base/statusor.h" +#include "utils/flatbuffers/flatbuffers.h" +#include "utils/grammar/parsing/derivation.h" +#include "utils/grammar/parsing/parse-tree.h" +#include "utils/grammar/semantics/eval-context.h" +#include "utils/grammar/semantics/evaluator.h" +#include "utils/grammar/semantics/expression_generated.h" +#include "utils/grammar/semantics/value.h" +#include "utils/grammar/text-context.h" + +namespace libtextclassifier3::grammar { + +// Semantic value composer. +// It evaluates a semantic expression of a syntactic parse tree as a semantic +// value. +// It evaluates the constituents of a rule match and applies them to semantic +// expression, calling out to semantic functions that implement the basic +// building blocks. +class SemanticComposer : public SemanticExpressionEvaluator { + public: + // Expects a flatbuffer schema that describes the possible result values of + // an evaluation. + explicit SemanticComposer(const reflection::Schema* semantic_values_schema); + + // Evaluates a semantic expression that is associated with the root of a parse + // tree. 
+ StatusOr<const SemanticValue*> Eval(const TextContext& text_context, + const Derivation& derivation, + UnsafeArena* arena) const; + + // Applies a semantic expression to a list of constituents and + // produces an output semantic value. + StatusOr<const SemanticValue*> Apply(const EvalContext& context, + const SemanticExpression* expression, + UnsafeArena* arena) const override; + + private: + // Evaluates a semantic expression against a parse tree. + StatusOr<const SemanticValue*> Eval(const TextContext& text_context, + const SemanticExpressionNode* derivation, + UnsafeArena* arena) const; + + std::unordered_map<SemanticExpression_::Expression, + std::unique_ptr<SemanticExpressionEvaluator>> + evaluators_; +}; + +} // namespace libtextclassifier3::grammar + +#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_COMPOSER_H_ diff --git a/utils/grammar/semantics/eval-context.h b/utils/grammar/semantics/eval-context.h new file mode 100644 index 0000000..612deb8 --- /dev/null +++ b/utils/grammar/semantics/eval-context.h
@@ -0,0 +1,44 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVAL_CONTEXT_H_ +#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVAL_CONTEXT_H_ + +#include <unordered_map> + +#include "utils/grammar/parsing/parse-tree.h" +#include "utils/grammar/semantics/value.h" +#include "utils/grammar/text-context.h" + +namespace libtextclassifier3::grammar { + +// Context for the evaluation of the semantic expression of a rule parse tree. +// This contains data about the evaluated constituents (named parts) of a rule +// and its match. +struct EvalContext { + // The input text. + const TextContext* text_context = nullptr; + + // The syntactic parse tree that is being evaluated. + const ParseTree* parse_tree = nullptr; + + // A map of an id of a rule constituent (named part of a rule match) to its + // evaluated semantic value. + std::unordered_map<int, const SemanticValue*> rule_constituents; +}; + +} // namespace libtextclassifier3::grammar + +#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVAL_CONTEXT_H_ diff --git a/utils/grammar/semantics/evaluator.h b/utils/grammar/semantics/evaluator.h new file mode 100644 index 0000000..4ed5a6c --- /dev/null +++ b/utils/grammar/semantics/evaluator.h
@@ -0,0 +1,41 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATOR_H_ +#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATOR_H_ + +#include "utils/base/arena.h" +#include "utils/base/statusor.h" +#include "utils/grammar/semantics/eval-context.h" +#include "utils/grammar/semantics/expression_generated.h" +#include "utils/grammar/semantics/value.h" + +namespace libtextclassifier3::grammar { + +// Interface for a semantic function that evaluates an expression and returns +// a semantic value. +class SemanticExpressionEvaluator { + public: + virtual ~SemanticExpressionEvaluator() = default; + + // Applies `expression` to the `context` to produce a semantic value. + virtual StatusOr<const SemanticValue*> Apply( + const EvalContext& context, const SemanticExpression* expression, + UnsafeArena* arena) const = 0; +}; + +} // namespace libtextclassifier3::grammar + +#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATOR_H_ diff --git a/utils/grammar/semantics/evaluators/arithmetic-eval.cc b/utils/grammar/semantics/evaluators/arithmetic-eval.cc new file mode 100644 index 0000000..171fdef --- /dev/null +++ b/utils/grammar/semantics/evaluators/arithmetic-eval.cc
@@ -0,0 +1,133 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "utils/grammar/semantics/evaluators/arithmetic-eval.h" + +#include <limits> + +namespace libtextclassifier3::grammar { +namespace { + +template <typename T> +StatusOr<const SemanticValue*> Reduce( + const SemanticExpressionEvaluator* composer, const EvalContext& context, + const ArithmeticExpression* expression, UnsafeArena* arena) { + T result; + switch (expression->op()) { + case ArithmeticExpression_::Operator_OP_ADD: { + result = 0; + break; + } + case ArithmeticExpression_::Operator_OP_MUL: { + result = 1; + break; + } + case ArithmeticExpression_::Operator_OP_MIN: { + result = std::numeric_limits<T>::max(); + break; + } + case ArithmeticExpression_::Operator_OP_MAX: { + result = std::numeric_limits<T>::min(); + break; + } + default: { + return Status(StatusCode::INVALID_ARGUMENT, + "Unexpected op: " + + std::string(ArithmeticExpression_::EnumNameOperator( + expression->op()))); + } + } + if (expression->values() != nullptr) { + for (const SemanticExpression* semantic_expression : + *expression->values()) { + TC3_ASSIGN_OR_RETURN( + const SemanticValue* value, + composer->Apply(context, semantic_expression, arena)); + if (value == nullptr) { + continue; + } + if (!value->Has<T>()) { + return Status( + StatusCode::INVALID_ARGUMENT, + "Argument didn't evaluate as expected type: " + + 
std::string(reflection::EnumNameBaseType(value->base_type()))); + } + const T scalar_value = value->Value<T>(); + switch (expression->op()) { + case ArithmeticExpression_::Operator_OP_ADD: { + result += scalar_value; + break; + } + case ArithmeticExpression_::Operator_OP_MUL: { + result *= scalar_value; + break; + } + case ArithmeticExpression_::Operator_OP_MIN: { + result = std::min(result, scalar_value); + break; + } + case ArithmeticExpression_::Operator_OP_MAX: { + result = std::max(result, scalar_value); + break; + } + default: { + break; + } + } + } + } + return SemanticValue::Create(result, arena); +} + +} // namespace + +StatusOr<const SemanticValue*> ArithmeticExpressionEvaluator::Apply( + const EvalContext& context, const SemanticExpression* expression, + UnsafeArena* arena) const { + TC3_DCHECK_EQ(expression->expression_type(), + SemanticExpression_::Expression_ArithmeticExpression); + const ArithmeticExpression* arithmetic_expression = + expression->expression_as_ArithmeticExpression(); + switch (arithmetic_expression->base_type()) { + case reflection::BaseType::Byte: + return Reduce<int8>(composer_, context, arithmetic_expression, arena); + case reflection::BaseType::UByte: + return Reduce<uint8>(composer_, context, arithmetic_expression, arena); + case reflection::BaseType::Short: + return Reduce<int16>(composer_, context, arithmetic_expression, arena); + case reflection::BaseType::UShort: + return Reduce<uint16>(composer_, context, arithmetic_expression, arena); + case reflection::BaseType::Int: + return Reduce<int32>(composer_, context, arithmetic_expression, arena); + case reflection::BaseType::UInt: + return Reduce<uint32>(composer_, context, arithmetic_expression, arena); + case reflection::BaseType::Long: + return Reduce<int64>(composer_, context, arithmetic_expression, arena); + case reflection::BaseType::ULong: + return Reduce<uint64>(composer_, context, arithmetic_expression, arena); + case reflection::BaseType::Float: + return 
Reduce<float>(composer_, context, arithmetic_expression, arena); + case reflection::BaseType::Double: + return Reduce<double>(composer_, context, arithmetic_expression, arena); + default: + return Status(StatusCode::INVALID_ARGUMENT, + "Unsupported for ArithmeticExpression: " + + std::string(reflection::EnumNameBaseType( + static_cast<reflection::BaseType>( + arithmetic_expression->base_type())))); + } +} + +} // namespace libtextclassifier3::grammar diff --git a/utils/grammar/semantics/evaluators/arithmetic-eval.h b/utils/grammar/semantics/evaluators/arithmetic-eval.h new file mode 100644 index 0000000..aafd513 --- /dev/null +++ b/utils/grammar/semantics/evaluators/arithmetic-eval.h
@@ -0,0 +1,46 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_ARITHMETIC_EVAL_H_ +#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_ARITHMETIC_EVAL_H_ + +#include "utils/base/arena.h" +#include "utils/grammar/semantics/eval-context.h" +#include "utils/grammar/semantics/evaluator.h" +#include "utils/grammar/semantics/expression_generated.h" +#include "utils/grammar/semantics/value.h" + +namespace libtextclassifier3::grammar { + +// Evaluates an arithmetic expression. +// Expects zero or more arguments and produces either sum, product, minimum or +// maximum of its arguments. If no arguments are specified, each operator +// returns its identity value. 
+class ArithmeticExpressionEvaluator : public SemanticExpressionEvaluator { + public: + explicit ArithmeticExpressionEvaluator( + const SemanticExpressionEvaluator* composer) + : composer_(composer) {} + + StatusOr<const SemanticValue*> Apply(const EvalContext& context, + const SemanticExpression* expression, + UnsafeArena* arena) const override; + + private: + const SemanticExpressionEvaluator* composer_; +}; + +} // namespace libtextclassifier3::grammar +#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_ARITHMETIC_EVAL_H_ diff --git a/utils/grammar/semantics/evaluators/compose-eval.cc b/utils/grammar/semantics/evaluators/compose-eval.cc new file mode 100644 index 0000000..139ec80 --- /dev/null +++ b/utils/grammar/semantics/evaluators/compose-eval.cc
@@ -0,0 +1,182 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "utils/grammar/semantics/evaluators/compose-eval.h" + +#include "utils/base/status_macros.h" +#include "utils/strings/stringpiece.h" + +namespace libtextclassifier3::grammar { +namespace { + +// Tries setting a singular field. +template <typename T> +Status TrySetField(const reflection::Field* field, const SemanticValue* value, + MutableFlatbuffer* result) { + if (!result->Set<T>(field, value->Value<T>())) { + return Status(StatusCode::INVALID_ARGUMENT, "Could not set field."); + } + return Status::OK; +} + +template <> +Status TrySetField<flatbuffers::Table>(const reflection::Field* field, + const SemanticValue* value, + MutableFlatbuffer* result) { + if (!result->Mutable(field)->MergeFrom(value->Table())) { + return Status(StatusCode::INVALID_ARGUMENT, + "Could not set sub-field in result."); + } + return Status::OK; +} + +// Tries adding a value to a repeated field. 
+template <typename T> +Status TryAddField(const reflection::Field* field, const SemanticValue* value, + MutableFlatbuffer* result) { + if (!result->Repeated(field)->Add(value->Value<T>())) { + return Status(StatusCode::INVALID_ARGUMENT, "Could not add field."); + } + return Status::OK; +} + +template <> +Status TryAddField<flatbuffers::Table>(const reflection::Field* field, + const SemanticValue* value, + MutableFlatbuffer* result) { + if (!result->Repeated(field)->Add()->MergeFrom(value->Table())) { + return Status(StatusCode::INVALID_ARGUMENT, + "Could not add message to repeated field."); + } + return Status::OK; +} + +// Tries adding or setting a value for a field. +template <typename T> +Status TrySetOrAddValue(const FlatbufferFieldPath* field_path, + const SemanticValue* value, MutableFlatbuffer* result) { + MutableFlatbuffer* parent; + const reflection::Field* field; + if (!result->GetFieldWithParent(field_path, &parent, &field)) { + return Status(StatusCode::INVALID_ARGUMENT, "Could not get field."); + } + if (field->type()->base_type() == reflection::Vector) { + return TryAddField<T>(field, value, parent); + } else { + return TrySetField<T>(field, value, parent); + } +} + +} // namespace + +StatusOr<const SemanticValue*> ComposeEvaluator::Apply( + const EvalContext& context, const SemanticExpression* expression, + UnsafeArena* arena) const { + const ComposeExpression* compose_expression = + expression->expression_as_ComposeExpression(); + std::unique_ptr<MutableFlatbuffer> result = + semantic_value_builder_.NewTable(compose_expression->type()); + + if (result == nullptr) { + return Status(StatusCode::INVALID_ARGUMENT, "Invalid result type."); + } + + // Evaluate and set fields. + if (compose_expression->fields() != nullptr) { + for (const ComposeExpression_::Field* field : + *compose_expression->fields()) { + // Evaluate argument. 
+ TC3_ASSIGN_OR_RETURN(const SemanticValue* value, + composer_->Apply(context, field->value(), arena)); + if (value == nullptr) { + continue; + } + + switch (value->base_type()) { + case reflection::BaseType::Bool: { + TC3_RETURN_IF_ERROR( + TrySetOrAddValue<bool>(field->path(), value, result.get())); + break; + } + case reflection::BaseType::Byte: { + TC3_RETURN_IF_ERROR( + TrySetOrAddValue<int8>(field->path(), value, result.get())); + break; + } + case reflection::BaseType::UByte: { + TC3_RETURN_IF_ERROR( + TrySetOrAddValue<uint8>(field->path(), value, result.get())); + break; + } + case reflection::BaseType::Short: { + TC3_RETURN_IF_ERROR( + TrySetOrAddValue<int16>(field->path(), value, result.get())); + break; + } + case reflection::BaseType::UShort: { + TC3_RETURN_IF_ERROR( + TrySetOrAddValue<uint16>(field->path(), value, result.get())); + break; + } + case reflection::BaseType::Int: { + TC3_RETURN_IF_ERROR( + TrySetOrAddValue<int32>(field->path(), value, result.get())); + break; + } + case reflection::BaseType::UInt: { + TC3_RETURN_IF_ERROR( + TrySetOrAddValue<uint32>(field->path(), value, result.get())); + break; + } + case reflection::BaseType::Long: { + TC3_RETURN_IF_ERROR( + TrySetOrAddValue<int64>(field->path(), value, result.get())); + break; + } + case reflection::BaseType::ULong: { + TC3_RETURN_IF_ERROR( + TrySetOrAddValue<uint64>(field->path(), value, result.get())); + break; + } + case reflection::BaseType::Float: { + TC3_RETURN_IF_ERROR( + TrySetOrAddValue<float>(field->path(), value, result.get())); + break; + } + case reflection::BaseType::Double: { + TC3_RETURN_IF_ERROR( + TrySetOrAddValue<double>(field->path(), value, result.get())); + break; + } + case reflection::BaseType::String: { + TC3_RETURN_IF_ERROR(TrySetOrAddValue<StringPiece>( + field->path(), value, result.get())); + break; + } + case reflection::BaseType::Obj: { + TC3_RETURN_IF_ERROR(TrySetOrAddValue<flatbuffers::Table>( + field->path(), value, result.get())); + break; + } + 
default: + return Status(StatusCode::INVALID_ARGUMENT, "Unhandled type."); + } + } + } + + return SemanticValue::Create<const MutableFlatbuffer*>(result.get(), arena); +} + +} // namespace libtextclassifier3::grammar diff --git a/utils/grammar/semantics/evaluators/compose-eval.h b/utils/grammar/semantics/evaluators/compose-eval.h new file mode 100644 index 0000000..50e7d25 --- /dev/null +++ b/utils/grammar/semantics/evaluators/compose-eval.h
@@ -0,0 +1,46 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_COMPOSE_EVAL_H_ +#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_COMPOSE_EVAL_H_ + +#include "utils/base/arena.h" +#include "utils/flatbuffers/mutable.h" +#include "utils/grammar/semantics/eval-context.h" +#include "utils/grammar/semantics/evaluator.h" +#include "utils/grammar/semantics/expression_generated.h" +#include "utils/grammar/semantics/value.h" + +namespace libtextclassifier3::grammar { + +// Combines arguments to a result type. 
+class ComposeEvaluator : public SemanticExpressionEvaluator { + public: + explicit ComposeEvaluator(const SemanticExpressionEvaluator* composer, + const reflection::Schema* semantic_values_schema) + : composer_(composer), semantic_value_builder_(semantic_values_schema) {} + + StatusOr<const SemanticValue*> Apply(const EvalContext& context, + const SemanticExpression* expression, + UnsafeArena* arena) const override; + + private: + const SemanticExpressionEvaluator* composer_; + const MutableFlatbufferBuilder semantic_value_builder_; +}; + +} // namespace libtextclassifier3::grammar + +#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_COMPOSE_EVAL_H_ diff --git a/utils/grammar/semantics/evaluators/const-eval.h b/utils/grammar/semantics/evaluators/const-eval.h new file mode 100644 index 0000000..e3f7ecf --- /dev/null +++ b/utils/grammar/semantics/evaluators/const-eval.h
@@ -0,0 +1,67 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_CONST_EVAL_H_ +#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_CONST_EVAL_H_ + +#include "utils/base/arena.h" +#include "utils/grammar/semantics/eval-context.h" +#include "utils/grammar/semantics/evaluator.h" +#include "utils/grammar/semantics/expression_generated.h" +#include "utils/grammar/semantics/value.h" + +namespace libtextclassifier3::grammar { + +// Returns a constant value of a given type. 
+class ConstEvaluator : public SemanticExpressionEvaluator { + public: + explicit ConstEvaluator(const reflection::Schema* semantic_values_schema) + : semantic_values_schema_(semantic_values_schema) {} + + StatusOr<const SemanticValue*> Apply(const EvalContext&, + const SemanticExpression* expression, + UnsafeArena* arena) const override { + TC3_DCHECK_EQ(expression->expression_type(), + SemanticExpression_::Expression_ConstValueExpression); + const ConstValueExpression* const_value_expression = + expression->expression_as_ConstValueExpression(); + const reflection::BaseType base_type = + static_cast<reflection::BaseType>(const_value_expression->base_type()); + const StringPiece data = StringPiece( + reinterpret_cast<const char*>(const_value_expression->value()->data()), + const_value_expression->value()->size()); + + if (base_type == reflection::BaseType::Obj) { + // Resolve the object type. + const int type_id = const_value_expression->type(); + if (type_id < 0 || + type_id >= semantic_values_schema_->objects()->size()) { + return Status(StatusCode::INVALID_ARGUMENT, "Invalid type."); + } + return SemanticValue::Create(semantic_values_schema_->objects()->Get( + const_value_expression->type()), + data, arena); + } else { + return SemanticValue::Create(base_type, data, arena); + } + } + + private: + const reflection::Schema* semantic_values_schema_; +}; + +} // namespace libtextclassifier3::grammar + +#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_CONST_EVAL_H_ diff --git a/utils/grammar/semantics/evaluators/constituent-eval.h b/utils/grammar/semantics/evaluators/constituent-eval.h new file mode 100644 index 0000000..ca0e09b --- /dev/null +++ b/utils/grammar/semantics/evaluators/constituent-eval.h
@@ -0,0 +1,50 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_CONSTITUENT_EVAL_H_ +#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_CONSTITUENT_EVAL_H_ + +#include "utils/base/arena.h" +#include "utils/grammar/semantics/eval-context.h" +#include "utils/grammar/semantics/evaluator.h" +#include "utils/grammar/semantics/expression_generated.h" +#include "utils/grammar/semantics/value.h" + +namespace libtextclassifier3::grammar { + +// Returns the semantic value of an evaluated constituent. +class ConstituentEvaluator : public SemanticExpressionEvaluator { + public: + StatusOr<const SemanticValue*> Apply(const EvalContext& context, + const SemanticExpression* expression, + UnsafeArena*) const override { + TC3_DCHECK_EQ(expression->expression_type(), + SemanticExpression_::Expression_ConstituentExpression); + const ConstituentExpression* constituent_expression = + expression->expression_as_ConstituentExpression(); + const auto constituent_it = + context.rule_constituents.find(constituent_expression->id()); + if (constituent_it != context.rule_constituents.end()) { + return constituent_it->second; + } + // The constituent was not present in the rule parse tree, return a + // null value for it. 
+ return nullptr; + } +}; + +} // namespace libtextclassifier3::grammar + +#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_CONSTITUENT_EVAL_H_ diff --git a/utils/grammar/semantics/evaluators/merge-values-eval.cc b/utils/grammar/semantics/evaluators/merge-values-eval.cc new file mode 100644 index 0000000..9415125 --- /dev/null +++ b/utils/grammar/semantics/evaluators/merge-values-eval.cc
@@ -0,0 +1,48 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "utils/grammar/semantics/evaluators/merge-values-eval.h" + +namespace libtextclassifier3::grammar { + +StatusOr<const SemanticValue*> MergeValuesEvaluator::Apply( + const EvalContext& context, const SemanticExpression* expression, + UnsafeArena* arena) const { + const MergeValueExpression* merge_value_expression = + expression->expression_as_MergeValueExpression(); + std::unique_ptr<MutableFlatbuffer> result = + semantic_value_builder_.NewTable(merge_value_expression->type()); + + if (result == nullptr) { + return Status(StatusCode::INVALID_ARGUMENT, "Invalid result type."); + } + + for (const SemanticExpression* semantic_expression : + *merge_value_expression->values()) { + TC3_ASSIGN_OR_RETURN(const SemanticValue* value, + composer_->Apply(context, semantic_expression, arena)); + if (value == nullptr) { + continue; + } + if ((value->type() != result->type()) || + !result->MergeFrom(value->Table())) { + return Status(StatusCode::INVALID_ARGUMENT, + "Could not merge the results."); + } + } + return SemanticValue::Create<const MutableFlatbuffer*>(result.get(), arena); +} + +} // namespace libtextclassifier3::grammar diff --git a/utils/grammar/semantics/evaluators/merge-values-eval.h b/utils/grammar/semantics/evaluators/merge-values-eval.h new file mode 100644 index 0000000..dac42a6 --- /dev/null +++ 
b/utils/grammar/semantics/evaluators/merge-values-eval.h
@@ -0,0 +1,49 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_MERGE_VALUES_EVAL_H_ +#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_MERGE_VALUES_EVAL_H_ + +#include "utils/base/arena.h" +#include "utils/base/status_macros.h" +#include "utils/flatbuffers/mutable.h" +#include "utils/grammar/semantics/eval-context.h" +#include "utils/grammar/semantics/evaluator.h" +#include "utils/grammar/semantics/expression_generated.h" +#include "utils/grammar/semantics/value.h" + +namespace libtextclassifier3::grammar { + +// Evaluate the “merge” semantic function expression. +// Conceptually, the way this merge evaluator works is that each of the +// arguments (semantic value) is merged into a return type semantic value. 
+class MergeValuesEvaluator : public SemanticExpressionEvaluator { + public: + explicit MergeValuesEvaluator( + const SemanticExpressionEvaluator* composer, + const reflection::Schema* semantic_values_schema) + : composer_(composer), semantic_value_builder_(semantic_values_schema) {} + + StatusOr<const SemanticValue*> Apply(const EvalContext& context, + const SemanticExpression* expression, + UnsafeArena* arena) const override; + + private: + const SemanticExpressionEvaluator* composer_; + const MutableFlatbufferBuilder semantic_value_builder_; +}; +} // namespace libtextclassifier3::grammar + +#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_MERGE_VALUES_EVAL_H_ diff --git a/utils/grammar/semantics/evaluators/parse-number-eval.h b/utils/grammar/semantics/evaluators/parse-number-eval.h new file mode 100644 index 0000000..10b2685 --- /dev/null +++ b/utils/grammar/semantics/evaluators/parse-number-eval.h
@@ -0,0 +1,109 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_PARSE_NUMBER_EVAL_H_ +#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_PARSE_NUMBER_EVAL_H_ + +#include <string> + +#include "utils/base/arena.h" +#include "utils/base/statusor.h" +#include "utils/grammar/semantics/eval-context.h" +#include "utils/grammar/semantics/evaluator.h" +#include "utils/grammar/semantics/expression_generated.h" +#include "utils/grammar/semantics/value.h" +#include "utils/strings/numbers.h" + +namespace libtextclassifier3::grammar { + +// Parses a string as a number. +class ParseNumberEvaluator : public SemanticExpressionEvaluator { + public: + explicit ParseNumberEvaluator(const SemanticExpressionEvaluator* composer) + : composer_(composer) {} + + StatusOr<const SemanticValue*> Apply(const EvalContext& context, + const SemanticExpression* expression, + UnsafeArena* arena) const override { + TC3_DCHECK_EQ(expression->expression_type(), + SemanticExpression_::Expression_ParseNumberExpression); + const ParseNumberExpression* parse_number_expression = + expression->expression_as_ParseNumberExpression(); + + // Evaluate argument. 
+ TC3_ASSIGN_OR_RETURN( + const SemanticValue* value, + composer_->Apply(context, parse_number_expression->value(), arena)); + if (value == nullptr) { + return nullptr; + } + if (!value->Has<StringPiece>()) { + return Status(StatusCode::INVALID_ARGUMENT, + "Argument didn't evaluate as a string value."); + } + const std::string data = value->Value<std::string>(); + + // Parse the string data as a number. + const reflection::BaseType type = + static_cast<reflection::BaseType>(parse_number_expression->base_type()); + if (flatbuffers::IsLong(type)) { + TC3_ASSIGN_OR_RETURN(const int64 value, TryParse<int64>(data)); + return SemanticValue::Create(type, value, arena); + } else if (flatbuffers::IsInteger(type)) { + TC3_ASSIGN_OR_RETURN(const int32 value, TryParse<int32>(data)); + return SemanticValue::Create(type, value, arena); + } else if (flatbuffers::IsFloat(type)) { + TC3_ASSIGN_OR_RETURN(const double value, TryParse<double>(data)); + return SemanticValue::Create(type, value, arena); + } else { + return Status(StatusCode::INVALID_ARGUMENT, + "Unsupported type: " + std::to_string(type)); + } + } + + private: + template <typename T> + bool Parse(const std::string& data, T* value) const; + + template <> + bool Parse(const std::string& data, int32* value) const { + return ParseInt32(data.data(), value); + } + + template <> + bool Parse(const std::string& data, int64* value) const { + return ParseInt64(data.data(), value); + } + + template <> + bool Parse(const std::string& data, double* value) const { + return ParseDouble(data.data(), value); + } + + template <typename T> + StatusOr<T> TryParse(const std::string& data) const { + T result; + if (!Parse<T>(data, &result)) { + return Status(StatusCode::INVALID_ARGUMENT, "Could not parse value."); + } + return result; + } + + const SemanticExpressionEvaluator* composer_; +}; + +} // namespace libtextclassifier3::grammar + +#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_PARSE_NUMBER_EVAL_H_ diff --git 
a/utils/grammar/semantics/evaluators/span-eval.h b/utils/grammar/semantics/evaluators/span-eval.h new file mode 100644 index 0000000..7539592 --- /dev/null +++ b/utils/grammar/semantics/evaluators/span-eval.h
@@ -0,0 +1,44 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_SPAN_EVAL_H_ +#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_SPAN_EVAL_H_ + +#include "annotator/types.h" +#include "utils/base/arena.h" +#include "utils/base/statusor.h" +#include "utils/grammar/semantics/eval-context.h" +#include "utils/grammar/semantics/evaluator.h" +#include "utils/grammar/semantics/expression_generated.h" +#include "utils/grammar/semantics/value.h" + +namespace libtextclassifier3::grammar { + +// Returns a value lifted from a parse tree. 
+class SpanAsStringEvaluator : public SemanticExpressionEvaluator { + public: + StatusOr<const SemanticValue*> Apply(const EvalContext& context, + const SemanticExpression* expression, + UnsafeArena* arena) const override { + TC3_DCHECK_EQ(expression->expression_type(), + SemanticExpression_::Expression_SpanAsStringExpression); + return SemanticValue::Create( + context.text_context->Span(context.parse_tree->codepoint_span), arena); + } +}; + +} // namespace libtextclassifier3::grammar + +#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_EVALUATORS_SPAN_EVAL_H_ diff --git a/utils/grammar/next/semantics/expression.fbs b/utils/grammar/semantics/expression.fbs similarity index 71% rename from utils/grammar/next/semantics/expression.fbs rename to utils/grammar/semantics/expression.fbs index 0f36df4..42bb0d4 100755 --- a/utils/grammar/next/semantics/expression.fbs +++ b/utils/grammar/semantics/expression.fbs
@@ -15,7 +15,7 @@ include "utils/flatbuffers/flatbuffers.fbs"; -namespace libtextclassifier3.grammar.next.SemanticExpression_; +namespace libtextclassifier3.grammar.SemanticExpression_; union Expression { ConstValueExpression, ConstituentExpression, @@ -23,16 +23,17 @@ SpanAsStringExpression, ParseNumberExpression, MergeValueExpression, + ArithmeticExpression, } // A semantic expression. -namespace libtextclassifier3.grammar.next; +namespace libtextclassifier3.grammar; table SemanticExpression { expression:SemanticExpression_.Expression; } // A constant flatbuffer value. -namespace libtextclassifier3.grammar.next; +namespace libtextclassifier3.grammar; table ConstValueExpression { // The base type of the value. base_type:int; @@ -46,14 +47,14 @@ } // The value of a rule constituent. -namespace libtextclassifier3.grammar.next; +namespace libtextclassifier3.grammar; table ConstituentExpression { // The id of the constituent. id:ushort; } // The fields to set. -namespace libtextclassifier3.grammar.next.ComposeExpression_; +namespace libtextclassifier3.grammar.ComposeExpression_; table Field { // The field to set. path:libtextclassifier3.FlatbufferFieldPath; @@ -64,7 +65,7 @@ // A combination: Compose a result from arguments. // https://mitpress.mit.edu/sites/default/files/sicp/full-text/book/book-Z-H-4.html#%_toc_%_sec_1.1.1 -namespace libtextclassifier3.grammar.next; +namespace libtextclassifier3.grammar; table ComposeExpression { // The id of the type of the result. type:int; @@ -73,12 +74,12 @@ } // Lifts a span as a value. -namespace libtextclassifier3.grammar.next; +namespace libtextclassifier3.grammar; table SpanAsStringExpression { } // Parses a string as a number. -namespace libtextclassifier3.grammar.next; +namespace libtextclassifier3.grammar; table ParseNumberExpression { // The base type of the value. base_type:int; @@ -87,7 +88,7 @@ } // Merge the semantic expressions. 
-namespace libtextclassifier3.grammar.next; +namespace libtextclassifier3.grammar; table MergeValueExpression { // The id of the type of the result. type:int; @@ -95,3 +96,23 @@ values:[SemanticExpression]; } +// The operator of the arithmetic expression. +namespace libtextclassifier3.grammar.ArithmeticExpression_; +enum Operator : int { + NO_OP = 0, + OP_ADD = 1, + OP_MUL = 2, + OP_MAX = 3, + OP_MIN = 4, +} + +// Simple arithmetic expression. +namespace libtextclassifier3.grammar; +table ArithmeticExpression { + // The base type of the operation. + base_type:int; + + op:ArithmeticExpression_.Operator; + values:[SemanticExpression]; +} + diff --git a/utils/grammar/semantics/value.h b/utils/grammar/semantics/value.h new file mode 100644 index 0000000..f0b5b19 --- /dev/null +++ b/utils/grammar/semantics/value.h
@@ -0,0 +1,217 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_VALUE_H_ +#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_VALUE_H_ + +#include "utils/base/arena.h" +#include "utils/base/logging.h" +#include "utils/flatbuffers/mutable.h" +#include "utils/flatbuffers/reflection.h" +#include "utils/strings/stringpiece.h" +#include "utils/utf8/unicodetext.h" +#include "flatbuffers/base.h" +#include "flatbuffers/reflection.h" + +namespace libtextclassifier3::grammar { + +// A semantic value as a typed, arena-allocated flatbuffer. +// This denotes the possible results of the evaluation of a semantic expression. +class SemanticValue { + public: + // Creates an arena allocated semantic value. 
+ template <typename T> + static const SemanticValue* Create(const T value, UnsafeArena* arena) { + static_assert(!std::is_pointer<T>() && std::is_scalar<T>()); + if (char* buffer = reinterpret_cast<char*>( + arena->AllocAligned(sizeof(T), alignof(T)))) { + flatbuffers::WriteScalar<T>(buffer, value); + return arena->AllocAndInit<SemanticValue>( + libtextclassifier3::flatbuffers_base_type<T>::value, + StringPiece(buffer, sizeof(T))); + } + return nullptr; + } + + template <> + const SemanticValue* Create(const StringPiece value, UnsafeArena* arena) { + return arena->AllocAndInit<SemanticValue>(reflection::BaseType::String, + value); + } + + template <> + const SemanticValue* Create(const UnicodeText value, UnsafeArena* arena) { + return arena->AllocAndInit<SemanticValue>( + reflection::BaseType::String, + StringPiece(value.data(), value.size_bytes())); + } + + template <> + const SemanticValue* Create(const MutableFlatbuffer* value, + UnsafeArena* arena) { + const std::string buffer = value->Serialize(); + return Create( + value->type(), + StringPiece(arena->Memdup(buffer.data(), buffer.size()), buffer.size()), + arena); + } + + static const SemanticValue* Create(const reflection::Object* type, + const StringPiece data, + UnsafeArena* arena) { + return arena->AllocAndInit<SemanticValue>(type, data); + } + + static const SemanticValue* Create(const reflection::BaseType base_type, + const StringPiece data, + UnsafeArena* arena) { + return arena->AllocAndInit<SemanticValue>(base_type, data); + } + + template <typename T> + static const SemanticValue* Create(const reflection::BaseType base_type, + const T value, UnsafeArena* arena) { + switch (base_type) { + case reflection::BaseType::Bool: + return Create( + static_cast< + flatbuffers_cpp_type<reflection::BaseType::Bool>::value>(value), + arena); + case reflection::BaseType::Byte: + return Create( + static_cast< + flatbuffers_cpp_type<reflection::BaseType::Byte>::value>(value), + arena); + case 
reflection::BaseType::UByte: + return Create( + static_cast< + flatbuffers_cpp_type<reflection::BaseType::UByte>::value>( + value), + arena); + case reflection::BaseType::Short: + return Create( + static_cast< + flatbuffers_cpp_type<reflection::BaseType::Short>::value>( + value), + arena); + case reflection::BaseType::UShort: + return Create( + static_cast< + flatbuffers_cpp_type<reflection::BaseType::UShort>::value>( + value), + arena); + case reflection::BaseType::Int: + return Create( + static_cast<flatbuffers_cpp_type<reflection::BaseType::Int>::value>( + value), + arena); + case reflection::BaseType::UInt: + return Create( + static_cast< + flatbuffers_cpp_type<reflection::BaseType::UInt>::value>(value), + arena); + case reflection::BaseType::Long: + return Create( + static_cast< + flatbuffers_cpp_type<reflection::BaseType::Long>::value>(value), + arena); + case reflection::BaseType::ULong: + return Create( + static_cast< + flatbuffers_cpp_type<reflection::BaseType::ULong>::value>( + value), + arena); + case reflection::BaseType::Float: + return Create( + static_cast< + flatbuffers_cpp_type<reflection::BaseType::Float>::value>( + value), + arena); + case reflection::BaseType::Double: + return Create( + static_cast< + flatbuffers_cpp_type<reflection::BaseType::Double>::value>( + value), + arena); + default: { + TC3_LOG(ERROR) << "Unhandled type: " << base_type; + return nullptr; + } + } + } + + explicit SemanticValue(const reflection::BaseType base_type, + const StringPiece data) + : base_type_(base_type), type_(nullptr), data_(data) {} + explicit SemanticValue(const reflection::Object* type, const StringPiece data) + : base_type_(reflection::BaseType::Obj), type_(type), data_(data) {} + + template <typename T> + bool Has() const { + return base_type_ == libtextclassifier3::flatbuffers_base_type<T>::value; + } + + template <> + bool Has<flatbuffers::Table>() const { + return base_type_ == reflection::BaseType::Obj; + } + + template <typename T = 
flatbuffers::Table> + const T* Table() const { + TC3_CHECK(Has<flatbuffers::Table>()); + return flatbuffers::GetRoot<T>( + reinterpret_cast<const unsigned char*>(data_.data())); + } + + template <typename T> + const T Value() const { + TC3_CHECK(Has<T>()); + return flatbuffers::ReadScalar<T>(data_.data()); + } + + template <> + const StringPiece Value<StringPiece>() const { + TC3_CHECK(Has<StringPiece>()); + return data_; + } + + template <> + const std::string Value<std::string>() const { + TC3_CHECK(Has<StringPiece>()); + return data_.ToString(); + } + + template <> + const UnicodeText Value<UnicodeText>() const { + TC3_CHECK(Has<StringPiece>()); + return UTF8ToUnicodeText(data_, /*do_copy=*/false); + } + + const reflection::BaseType base_type() const { return base_type_; } + const reflection::Object* type() const { return type_; } + + private: + // The base type. + const reflection::BaseType base_type_; + + // The object type of the value. + const reflection::Object* type_; + + StringPiece data_; +}; + +} // namespace libtextclassifier3::grammar + +#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_VALUE_H_ diff --git a/utils/grammar/text-context.h b/utils/grammar/text-context.h new file mode 100644 index 0000000..6fc0024 --- /dev/null +++ b/utils/grammar/text-context.h
@@ -0,0 +1,56 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_TEXT_CONTEXT_H_ +#define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_TEXT_CONTEXT_H_ + +#include <vector> + +#include "annotator/types.h" +#include "utils/i18n/locale.h" +#include "utils/utf8/unicodetext.h" + +namespace libtextclassifier3::grammar { + +// Input to the parser. +struct TextContext { + // Returns a view on a span of the text. + const UnicodeText Span(const CodepointSpan& span) const { + return text.Substring(codepoints[span.first], codepoints[span.second], + /*do_copy=*/false); + } + + // The input text. + UnicodeText text; + + // Pre-enumerated codepoints for fast substring extraction. + std::vector<UnicodeText::const_iterator> codepoints; + + // The tokenized input text. + std::vector<Token> tokens; + + // Locales of the input text. + std::vector<Locale> locales; + + // Text annotations. + std::vector<AnnotatedSpan> annotations; + + // The span of tokens to consider. + TokenSpan context_span; +}; + +}; // namespace libtextclassifier3::grammar + +#endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_TEXT_CONTEXT_H_ diff --git a/utils/grammar/types.h b/utils/grammar/types.h index 7786d31..0255344 100644 --- a/utils/grammar/types.h +++ b/utils/grammar/types.h
@@ -13,8 +13,6 @@ // limitations under the License. // -#pragma GCC diagnostic ignored "-Wc++17-extensions" - // Common definitions used in the grammar system. #ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_TYPES_H_
diff --git a/utils/grammar/utils/ir.cc b/utils/grammar/utils/ir.cc index 545dc72..12558fc 100644 --- a/utils/grammar/utils/ir.cc +++ b/utils/grammar/utils/ir.cc
@@ -191,15 +191,6 @@ continue; } - // If either callback is a filter, we can't share as we must always run - // both filters. - if ((lhs.callback.id != kNoCallback && - filters_.find(lhs.callback.id) != filters_.end()) || - (candidate->callback.id != kNoCallback && - filters_.find(candidate->callback.id) != filters_.end())) { - continue; - } - // If the nonterminal is already defined, it must match for sharing. if (lhs.nonterminal != kUnassignedNonterm && lhs.nonterminal != candidate->nonterminal) { @@ -405,13 +396,6 @@ void Ir::Serialize(const bool include_debug_information, RulesSetT* output) const { - // Set callback information. - for (const CallbackId filter_callback_id : filters_) { - output->callback.push_back(RulesSet_::CallbackEntry( - filter_callback_id, RulesSet_::Callback(/*is_filter=*/true))); - } - SortStructsForBinarySearchLookup(&output->callback); - // Add information about predefined nonterminal classes. output->nonterminals.reset(new RulesSet_::NonterminalsT); output->nonterminals->start_nt = GetNonterminalForName(kStartNonterm); diff --git a/utils/grammar/utils/ir.h b/utils/grammar/utils/ir.h index db53b4c..9c1b37f 100644 --- a/utils/grammar/utils/ir.h +++ b/utils/grammar/utils/ir.h
@@ -95,9 +95,8 @@ std::unordered_map<TwoNonterms, LhsSet, BinaryRuleHasher> binary_rules; }; - explicit Ir(const std::unordered_set<CallbackId>& filters = {}, - const int num_shards = 1) - : num_nonterminals_(0), filters_(filters), shards_(num_shards) {} + explicit Ir(const int num_shards = 1) + : num_nonterminals_(0), shards_(num_shards) {} // Adds a new non-terminal. Nonterm AddNonterminal(const std::string& name = "") { @@ -224,9 +223,6 @@ Nonterm num_nonterminals_; std::unordered_set<Nonterm> nonshareable_; - // The set of callbacks that should be treated as filters. - std::unordered_set<CallbackId> filters_; - // The sharded rules. std::vector<RulesShard> shards_; diff --git a/utils/grammar/utils/rules.cc b/utils/grammar/utils/rules.cc index 2209100..044da4d 100644 --- a/utils/grammar/utils/rules.cc +++ b/utils/grammar/utils/rules.cc
@@ -160,9 +160,16 @@ void Rules::AddAlias(const std::string& nonterminal_name, const std::string& alias) { +#ifndef TC3_USE_CXX14 + TC3_CHECK_EQ(nonterminal_alias_.insert_or_assign(alias, nonterminal_name) + .first->second, + nonterminal_name) + << "Cannot redefine alias: " << alias; +#else nonterminal_alias_[alias] = nonterminal_name; TC3_CHECK_EQ(nonterminal_alias_[alias], nonterminal_name) << "Cannot redefine alias: " << alias; +#endif } // Defines a nonterminal for an externally provided annotation. @@ -301,7 +308,7 @@ const int8 max_whitespace_gap, const bool case_sensitive, const int shard) { // Resolve anchors and fillers. - const std::vector<RhsElement> optimized_rhs = OptimizeRhs(rhs); + const std::vector optimized_rhs = OptimizeRhs(rhs); std::vector<int> optional_element_indices; TC3_CHECK_LT(optional_element_indices.size(), optimized_rhs.size()) @@ -406,7 +413,7 @@ } Ir Rules::Finalize(const std::set<std::string>& predefined_nonterminals) const { - Ir rules(filters_, num_shards_); + Ir rules(num_shards_); std::unordered_map<int, Nonterm> nonterminal_ids; // Pending rules to process. @@ -422,7 +429,7 @@ } // Assign (unmergeable) Nonterm values to any nonterminals that have - // multiple rules or that have a filter callback on some rule. + // multiple rules. for (int i = 0; i < nonterminals_.size(); i++) { const NontermInfo& nonterminal = nonterminals_[i]; @@ -435,15 +442,8 @@ (nonterminal.from_annotation || nonterminal.rules.size() > 1 || !nonterminal.regex_rules.empty()); for (const int rule_index : nonterminal.rules) { - const Rule& rule = rules_[rule_index]; - // Schedule rule. scheduled_rules.insert({i, rule_index}); - - if (rule.callback != kNoCallback && - filters_.find(rule.callback) != filters_.end()) { - unmergeable = true; - } } if (unmergeable) { diff --git a/utils/grammar/utils/rules.h b/utils/grammar/utils/rules.h index 2360863..96db302 100644 --- a/utils/grammar/utils/rules.h +++ b/utils/grammar/utils/rules.h
@@ -33,19 +33,15 @@ // All rules for a grammar will be collected in a rules object. // // Rules r; -// CallbackId date_output_callback = 1; -// CallbackId day_filter_callback = 2; r.DefineFilter(day_filter_callback); -// CallbackId year_filter_callback = 3; r.DefineFilter(year_filter_callback); -// r.Add("<date>", {"<monthname>", "<day>", <year>"}, -// date_output_callback); +// r.Add("<date>", {"<monthname>", "<day>", <year>"}); // r.Add("<monthname>", {"January"}); // ... // r.Add("<monthname>", {"December"}); -// r.Add("<day>", {"<string_of_digits>"}, day_filter_callback); -// r.Add("<year>", {"<string_of_digits>"}, year_filter_callback); +// r.Add("<day>", {"<string_of_digits>"}); +// r.Add("<year>", {"<string_of_digits>"}); // -// The Add() method adds a rule with a given lhs, rhs, and (optionally) -// callback. The rhs is just a list of terminals and nonterminals. Anything +// The Add() method adds a rule with a given lhs, rhs/ +// The rhs is just a list of terminals and nonterminals. Anything // surrounded in angle brackets is considered a nonterminal. A "?" can follow // any element of the RHS, like this: // @@ -54,9 +50,8 @@ // This indicates that the <day> and "," parts of the rhs are optional. // (This is just notational shorthand for adding a bunch of rules.) // -// Once you're done adding rules and callbacks to the Rules object, -// call r.Finalize() on it. This lowers the rule set into an internal -// representation. +// Once you're done adding rules, r.Finalize() lowers the rule set into an +// internal representation. class Rules { public: explicit Rules(const int num_shards = 1) : num_shards_(num_shards) {} @@ -172,9 +167,6 @@ // nonterminal. void AddAlias(const std::string& nonterminal_name, const std::string& alias); - // Defines a new filter id. - void DefineFilter(const CallbackId filter_id) { filters_.insert(filter_id); } - // Lowers the rule set into the intermediate representation. 
// Treats nonterminals given by the argument `predefined_nonterminals` as // defined externally. This allows to define rules that are dependent on @@ -232,9 +224,6 @@ // Rules. std::vector<Rule> rules_; std::vector<std::string> regex_rules_; - - // Ids of callbacks that should be treated as filters. - std::unordered_set<CallbackId> filters_; }; } // namespace libtextclassifier3::grammar diff --git a/utils/i18n/locale-list.cc b/utils/i18n/locale-list.cc new file mode 100644 index 0000000..f951eec --- /dev/null +++ b/utils/i18n/locale-list.cc
@@ -0,0 +1,43 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "utils/i18n/locale-list.h" + +#include <string> + +namespace libtextclassifier3 { + +LocaleList LocaleList::ParseFrom(const std::string& locale_tags) { + std::vector<StringPiece> split_locales = strings::Split(locale_tags, ','); + std::string reference_locale; + if (!split_locales.empty()) { + // Assigns the first parsed locale to reference_locale. + reference_locale = split_locales[0].ToString(); + } else { + reference_locale = ""; + } + std::vector<Locale> locales; + for (const StringPiece& locale_str : split_locales) { + const Locale locale = Locale::FromBCP47(locale_str.ToString()); + if (!locale.IsValid()) { + TC3_LOG(WARNING) << "Failed to parse the detected_text_language_tag: " + << locale_str.ToString(); + } + locales.push_back(locale); + } + return LocaleList(locales, split_locales, reference_locale); +} + +} // namespace libtextclassifier3 diff --git a/utils/i18n/locale-list.h b/utils/i18n/locale-list.h new file mode 100644 index 0000000..78af19b --- /dev/null +++ b/utils/i18n/locale-list.h
@@ -0,0 +1,54 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#ifndef LIBTEXTCLASSIFIER_UTILS_I18N_LOCALE_LIST_H_ +#define LIBTEXTCLASSIFIER_UTILS_I18N_LOCALE_LIST_H_ + +#include <string> + +#include "utils/i18n/locale.h" +#include "utils/strings/split.h" + +namespace libtextclassifier3 { + +// Parses and hold data about locales (combined by delimiter ','). +class LocaleList { + public: + // Constructs the + // - Collection of locale tag from local_tags + // - Collection of Locale objects from a valid BCP47 tag. (If the tag is + // invalid, an object is created but return false for IsInvalid() call. + // - Assigns the first parsed locale to reference_locale. 
+ static LocaleList ParseFrom(const std::string& locale_tags); + + std::vector<Locale> GetLocales() const { return locales_; } + std::vector<StringPiece> GetLocaleTags() const { return split_locales_; } + std::string GetReferenceLocale() const { return reference_locale_; } + + private: + LocaleList(const std::vector<Locale>& locales, + const std::vector<StringPiece>& split_locales, + const StringPiece& reference_locale) + : locales_(locales), + split_locales_(split_locales), + reference_locale_(reference_locale.ToString()) {} + + const std::vector<Locale> locales_; + const std::vector<StringPiece> split_locales_; + const std::string reference_locale_; +}; +} // namespace libtextclassifier3 + +#endif // LIBTEXTCLASSIFIER_UTILS_I18N_LOCALE_LIST_H_ diff --git a/utils/tokenizer.cc b/utils/tokenizer.cc index 2ee9e21..5a4f79a 100644 --- a/utils/tokenizer.cc +++ b/utils/tokenizer.cc
@@ -49,6 +49,10 @@ SortCodepointRanges(internal_tokenizer_codepoint_ranges, &internal_tokenizer_codepoint_ranges_); + if (type_ == TokenizationType_MIXED && split_on_script_change) { + TC3_LOG(ERROR) << "The option `split_on_script_change` is unavailable for " + "the selected tokenizer type (mixed)."; + } } const TokenizationCodepointRangeT* Tokenizer::FindTokenizationRange( @@ -233,15 +237,20 @@ if (!break_iterator) { return false; } + const int context_unicode_size = context_unicode.size_codepoints(); int last_unicode_index = 0; int unicode_index = 0; auto token_begin_it = context_unicode.begin(); while ((unicode_index = break_iterator->Next()) != UniLib::BreakIterator::kDone) { const int token_length = unicode_index - last_unicode_index; + if (token_length + last_unicode_index > context_unicode_size) { + return false; + } auto token_end_it = token_begin_it; std::advance(token_end_it, token_length); + TC3_CHECK(token_end_it <= context_unicode.end()); // Determine if the whole token is whitespace. bool is_whitespace = true; diff --git a/utils/utf8/unilib-common.cc b/utils/utf8/unilib-common.cc index cbc1119..1603dc8 100644 --- a/utils/utf8/unilib-common.cc +++ b/utils/utf8/unilib-common.cc
@@ -406,6 +406,10 @@ 0x275E, 0x276E, 0x276F, 0x2E42, 0x301D, 0x301E, 0x301F, 0xFF02}; constexpr int kNumQuotation = ARRAYSIZE(kQuotation); +// Source: https://unicode-search.net/unicode-namesearch.pl?term=ampersand +constexpr char32 kAmpersand[] = {0x0026, 0xFE60, 0xFF06, 0x1F674, 0x1F675}; +constexpr int kNumAmpersand = ARRAYSIZE(kAmpersand); + #undef ARRAYSIZE static_assert(kNumOpeningBrackets == kNumClosingBrackets, @@ -595,6 +599,10 @@ return GetMatchIndex(kQuotation, kNumQuotation, codepoint) >= 0; } +bool IsAmpersand(char32 codepoint) { + return GetMatchIndex(kAmpersand, kNumAmpersand, codepoint) >= 0; +} + bool IsLatinLetter(char32 codepoint) { return (GetOverlappingRangeIndex( kLatinLettersRangesStart, kLatinLettersRangesEnd, diff --git a/utils/utf8/unilib-common.h b/utils/utf8/unilib-common.h index 1fdfdb3..cc0a9e5 100644 --- a/utils/utf8/unilib-common.h +++ b/utils/utf8/unilib-common.h
@@ -36,6 +36,7 @@ bool IsDot(char32 codepoint); bool IsApostrophe(char32 codepoint); bool IsQuotation(char32 codepoint); +bool IsAmpersand(char32 codepoint); bool IsLatinLetter(char32 codepoint); bool IsArabicLetter(char32 codepoint);
diff --git a/utils/utf8/unilib-icu.cc b/utils/utf8/unilib-icu.cc index a42f78c..ba7f0e1 100644 --- a/utils/utf8/unilib-icu.cc +++ b/utils/utf8/unilib-icu.cc
@@ -19,7 +19,10 @@ #include <utility> #include "utils/base/logging.h" +#include "utils/base/statusor.h" #include "utils/utf8/unilib-common.h" +#include "unicode/unistr.h" +#include "unicode/utext.h" namespace libtextclassifier3 { @@ -113,6 +116,11 @@ return u_getBidiPairedBracket(codepoint); } +StatusOr<int32> UniLibBase::Length(const UnicodeText& text) const { + return icu::UnicodeString::fromUTF8({text.data(), text.size_bytes()}) + .countChar32(); +} + UniLibBase::RegexMatcher::RegexMatcher(icu::RegexPattern* pattern, icu::UnicodeString text) : text_(std::move(text)), diff --git a/utils/utf8/unilib-icu.h b/utils/utf8/unilib-icu.h index 301fe4d..a1bb40f 100644 --- a/utils/utf8/unilib-icu.h +++ b/utils/utf8/unilib-icu.h
@@ -25,6 +25,7 @@ #include <mutex> // NOLINT(build/c++11) #include "utils/base/integral_types.h" +#include "utils/base/statusor.h" #include "utils/utf8/unicodetext.h" #include "unicode/brkiter.h" #include "unicode/errorcode.h" @@ -52,6 +53,8 @@ char32 ToUpper(char32 codepoint) const; char32 GetPairedBracket(char32 codepoint) const; + StatusOr<int32> Length(const UnicodeText& text) const; + // Forward declaration for friend. class RegexPattern;
diff --git a/utils/utf8/unilib.h b/utils/utf8/unilib.h index 0d6d1e5..e33f3be 100644 --- a/utils/utf8/unilib.h +++ b/utils/utf8/unilib.h
@@ -102,6 +102,10 @@ return libtextclassifier3::IsQuotation(codepoint); } + bool IsAmpersand(char32 codepoint) const { + return libtextclassifier3::IsAmpersand(codepoint); + } + bool IsLatinLetter(char32 codepoint) const { return libtextclassifier3::IsLatinLetter(codepoint); } @@ -137,6 +141,31 @@ bool IsLetter(char32 codepoint) const { return libtextclassifier3::IsLetter(codepoint); } + + bool IsValidUtf8(const UnicodeText& text) const { + // Basic check of structural validity of UTF8. + if (!text.is_valid()) { + return false; + } + // In addition to that, we declare that a valid UTF8 is when the number of + // codepoints in the string as measured by ICU is the same as the number of + // codepoints as measured by UnicodeText. Because if we don't do this check, + // the indices might differ, and cause trouble, because the assumption + // throughout the code is that ICU indices and UnicodeText indices are the + // same. + // NOTE: This is not perfect, as this doesn't check the alignment of the + // codepoints, but for the practical purposes should be enough. + const StatusOr<int32> icu_length = Length(text); + if (!icu_length.ok()) { + return false; + } + + if (icu_length.ValueOrDie() != text.size_codepoints()) { + return false; + } + + return true; + } }; } // namespace libtextclassifier3